initial commit
authorOlivier Dion <odion@efficios.com>
Wed, 13 Mar 2024 19:43:20 +0000 (15:43 -0400)
committerOlivier Dion <odion@efficios.com>
Wed, 13 Mar 2024 19:43:20 +0000 (15:43 -0400)
Signed-off-by: Olivier Dion <odion@efficios.com>
Makefile [new file with mode: 0644]
README.md [new file with mode: 0644]
check [new file with mode: 0755]
dev-env [new file with mode: 0755]
lttng-auto-mpi-wrappers [new file with mode: 0755]
lttng-auto-ust-api [new file with mode: 0755]
lttng/ust-context-provider.h [new file with mode: 0644]
run-test-mpi [new file with mode: 0755]
test-mpi.c [new file with mode: 0644]
test.c [new file with mode: 0644]

diff --git a/Makefile b/Makefile
new file mode 100644 (file)
index 0000000..e7b5fdd
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: MIT
+#
+# Copyright (c) 2023 Olivier Dion <odion@efficios.com>
+
+# Configure me:
+
+## The provider of the events.
+PROVIDER?=mpi
+
+## The MPI header to use.
+MPI_HEADER?=$(shell pkg-config --variable=includedir ompi)/mpi.h
+
+LTTNG_UST_CFLAGS=$(shell pkg-config --cflags lttng-ust)
+LTTNG_UST_LIBS=$(shell pkg-config --libs lttng-ust)
+
+# Do not modify below.
+EXTRA_CFLAGS=$(LTTNG_UST_CFLAGS)
+EXTRA_LIBS=$(LTTNG_UST_LIBS) -ldl
+
+AUTOGEN_MPI_API=lttng-ust-mpi-defs.h lttng-ust-mpi.h lttng-ust-mpi-classes.h lttng-ust-mpi-impl.c lttng-ust-mpi-states.h
+
+all: liblttng-ust-mpi.so test-mpi
+
+clean:
+       rm -f $(AUTOGEN_MPI_API)
+       rm -f null.o
+       rm -f liblttng-ust-mpi.so test-mpi lttng-ust-mpi.c
+       rm -rf traces
+
+check: all
+       ./check
+
+test-mpi: test-mpi.c
+       mpicc -o $@ $^
+
+liblttng-ust-mpi.so: lttng-ust-mpi.c $(AUTOGEN_MPI_API)
+       gcc -O2 -D ENABLE_LTTNG_TRACEPOINTS -I . -I $$(dirname $(MPI_HEADER)) -Wall -Werror -shared -fPIC $(EXTRA_CFLAGS) -o $@ lttng-ust-mpi.c lttng-ust-mpi-impl.c $(EXTRA_LIBS)
+
+lttng-ust-mpi.c: $(MPI_HEADER) | lttng-auto-mpi-wrappers
+       ./lttng-auto-mpi-wrappers $(MPI_HEADER) $@
+
+$(AUTOGEN_MPI_API) &: $(MPI_HEADER)
+       ./lttng-auto-ust-api --ignore MPI_Pcontrol --provider $(PROVIDER) --emulated-classes --common-prefix MPI_ $^ $(AUTOGEN_MPI_API)
+
+.PHONY: all check clean
diff --git a/README.md b/README.md
new file mode 100644 (file)
index 0000000..0cc2482
--- /dev/null
+++ b/README.md
@@ -0,0 +1,34 @@
+# lttng-ust-mpi
+
+Auto generation of instrumentation library for MPI with LTTng.
+
+# Requirements
+
+- babeltrace (check program)
+- clang
+- gcc
+- lttng-tools (check program)
+- lttng-ust
+- make
+- openmpi or craympi
+- openssh or srun (check program)
+- pkg-config
+- python-clang
+
+# Build
+
+Call `make` and that is it.  You can pass `MPI_HEADER=path-to-mpi.h`.  This will
+bypass the pkg-config finding of mpi.h.  You can also pass
+`PROVIDER=my_provider` if you want to change the default provider name (mpi).
+
+The resulting `liblttng-ust-mpi.so` can be used to instrument OpenMPI or CrayMPI
+with LTTng.  See the `check` script for a usage example.
+
+# Set of ignored functions
+
+The only function ignored is `MPI_Pcontrol`.  One can extend the `forbiden_list`
+in `lttng-auto-mpi-wrappers` to ignore more functions.
+
+# Development
+
+If using the `guix` package manager, simply do `./dev-env`.
diff --git a/check b/check
new file mode 100755 (executable)
index 0000000..f557dc1
--- /dev/null
+++ b/check
@@ -0,0 +1,28 @@
+#!/bin/sh
+#
+# SPDX-License-Identifier: MIT
+#
+# Copyright (c) 2023 Olivier Dion <odion@efficios.com>
+
+TRACE_OUTPUT=traces
+
+lttng create --output "$TRACE_OUTPUT"
+
+# Enable all MPI event types.
+lttng enable-event --userspace 'mpi:*'
+
+# Add the MPI rank as an application context to every event.
+lttng add-context --userspace --type '$app.MPI:rank'
+
+lttng start
+
+if command -v mpirun >/dev/null 2>&1;
+then
+    mpirun --n 4 ./run-test-mpi 100
+else
+    srun -p cray --ntasks=4 ./run-test-mpi 1000
+fi
+
+lttng destroy
+
+babeltrace2 "$TRACE_OUTPUT"
diff --git a/dev-env b/dev-env
new file mode 100755 (executable)
index 0000000..dac911b
--- /dev/null
+++ b/dev-env
@@ -0,0 +1,25 @@
+#!/bin/sh
+#
+# SPDX-License-Identifier: MIT
+#
+# Copyright (c) 2023 Olivier Dion <odion@efficios.com>
+
+guix shell --pure  \
+     babeltrace    \
+     coreutils     \
+     findutils     \
+     gawk          \
+     gcc-toolchain \
+     clang         \
+     git           \
+     grep          \
+     less          \
+     lttng-tools   \
+     lttng-ust     \
+     make          \
+     openmpi       \
+     openssh       \
+     pkg-config    \
+     python        \
+     python-clang  \
+     -- "$@"
diff --git a/lttng-auto-mpi-wrappers b/lttng-auto-mpi-wrappers
new file mode 100755 (executable)
index 0000000..d6e7948
--- /dev/null
@@ -0,0 +1,274 @@
+#!/usr/bin/env python3
+#
+# SPDX-License-Identifier: MIT
+#
+# Copyright (c) 2023 EfficiOS, Inc.
+#
+# Author: Olivier Dion <odion@efficios.com>
+#
+# Auto-generate lttng-ust tracepoints for OpenMPI.
+#
+# Require: python-clang (libclang)
+
+import argparse
+import re
+
+from string import Template
+
+import clang.cindex
+
+def list_function_declarations(root):
+    """Return the direct children of ROOT that are function declarations."""
+    declarations = []
+    for cursor in root.get_children():
+        if cursor.kind == clang.cindex.CursorKind.FUNCTION_DECL:
+            declarations.append(cursor)
+    return declarations
+
+def parse_header(header_file):
+    """Parse HEADER_FILE with libclang and return the root cursor of the result."""
+    index = clang.cindex.Index.create()
+    translation_unit = index.parse(header_file)
+    return translation_unit.cursor
+
+def list_functions(root):
+    """Return the function declarations under ROOT whose name starts with `MPI_'."""
+    functions = []
+    for fn in list_function_declarations(root):
+        # A name that starts with "MPI_" is necessarily non-empty, so the
+        # prefix test alone selects the same declarations.
+        if fn.spelling.startswith("MPI_"):
+            functions.append(fn)
+    return functions
+
+def exact_definition(arg):
+    """Return the C declaration `TYPE NAME' of the argument ARG.
+
+    When the type spelling contains array brackets (e.g. `int[3]'), the
+    brackets are moved after the argument name (`int NAME[3]').
+    """
+    type_spelling = arg.type.spelling
+    brackets = re.search(r'(\[[0-9]*\])+', type_spelling)
+    if brackets is None:
+        return "%s %s" % (type_spelling, arg.spelling)
+    element_type = type_spelling[:brackets.start(0)]
+    return "%s %s%s" % (element_type, arg.spelling, brackets.group(0))
+
+# Names of the functions for which no wrapper is generated.
+forbiden_list = {
+    "MPI_Pcontrol"
+}
+
+# C code run by the wrappers of the MPI initialization functions once the real
+# function has succeeded: query the rank of the process in the world
+# communicator and register the `$app.MPI' context provider.
+_register_rank_context = """
+       if (MPI_SUCCESS == ret) {
+               int (*mpi_comm_rank)(MPI_Comm, int *rank);
+               MPI_Comm mpi_comm_world;
+#ifdef CRAY_MPICH_VERSION
+               mpi_comm_world = MPI_COMM_WORLD;
+#else
+               mpi_comm_world = *(void**)resolve_or_die("ompi_mpi_comm_world_addr");
+#endif
+               mpi_comm_rank  = resolve_or_die("PMPI_Comm_rank");
+               mpi_comm_rank(mpi_comm_world, &mpi_rank);
+               mpi_provider.priv = lttng_ust_context_provider_register(&mpi_provider);
+       }
+"""
+
+# Extra C code, keyed by function name, inserted in the generated wrapper
+# between the call to the real function and the `return' statement.
+extra_works = {
+    "MPI_Init": _register_rank_context,
+    # MPI_Init_thread is the alternative entry point used by programs that
+    # request a thread support level; without this entry the rank context
+    # would never be registered for them.
+    "MPI_Init_thread": _register_rank_context,
+    "MPI_Finalize": """
+       if (mpi_provider.priv) {
+               lttng_ust_context_provider_unregister(mpi_provider.priv);
+       }
+""",
+}
+
+def main():
+    """Generate the C source of the MPI wrappers.
+
+    Command line: `api' is the MPI header to parse and `wrappers' the path of
+    the C file to write.  The output starts with a fixed prelude (symbol
+    resolution helpers and the `$app.MPI' context provider) followed by one
+    wrapper per `MPI_*' function declared in the header, except for the
+    functions listed in `forbiden_list'.  Every wrapper resolves the matching
+    `PMPI_*' symbol on first use and forwards the call to it.
+    """
+
+    parser = argparse.ArgumentParser(prog="lttng-ust-auto-mpi")
+
+    parser.add_argument("api",
+                        help="MPI API header")
+
+    parser.add_argument("wrappers",
+                        help="Path to MPI wrappers")
+
+    args = parser.parse_args()
+
+    fn_tpl = Template("""
+${ret_type} ${fn_name}(${fn_arguments})
+{
+       ${ret_type} ret;
+       {
+               static ${ret_type}(*real_fn)(${fn_arguments}) = NULL;
+               if (unlikely(NULL == __atomic_load_n(&real_fn, __ATOMIC_RELAXED))) {
+                       void *result = resolve_or_die("P${fn_name}");
+                       __atomic_store_n(&real_fn, result, __ATOMIC_RELAXED);
+               }
+               LTTNG_MAKE_API_OBJECT(${fn_name}${fn_rest_argument_names});
+               ret = real_fn(${fn_pass_argument_names});
+               LTTNG_MARK_RETURN_API_OBJECT(ret);
+       }
+$extra_work
+       return ret;
+}
+""")
+
+    with open(args.wrappers, "w") as output:
+        # The prelude includes <stdint.h>, <stdio.h> and <strings.h> itself
+        # because it uses int64_t, fprintf()/stderr and index() directly.
+        output.write("""/* Auto-generated */
+#define _GNU_SOURCE
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <strings.h>
+
+#include <dlfcn.h>
+
+#include <mpi.h>
+
+#include <lttng/ust-events.h>
+#include <lttng/ust-ringbuffer-context.h>
+
+#include "lttng/ust-context-provider.h"
+
+#include "lttng-ust-mpi-states.h"
+
+#define likely(x)   __builtin_expect(!!(x), 1)
+#define unlikely(x) __builtin_expect(!!(x), 0)
+
+#define die(fmt, ...)                                          \\
+       do {                                                    \\
+               fprintf(stderr, fmt "\\n", ##__VA_ARGS__);      \\
+               exit(EXIT_FAILURE);                             \\
+       } while (0)
+
+static void *resolve_or_die(const char *symbol)
+{
+       void *ret = dlsym(RTLD_NEXT, symbol);
+       if (unlikely(!ret)) {
+               die("could not resolve `%s': %s", symbol, dlerror());
+       }
+       return ret;
+}
+
+static inline int streq(const char *A, const char *B)
+{
+       return 0 == strcmp(A, B);
+}
+
+static inline char *context_type(struct lttng_ust_app_context *uctx)
+{
+       char *suffix = index(uctx->ctx_name, ':');
+
+       if (likely(suffix)) {
+               suffix = &suffix[1]; /* Skip ':' */
+       }
+
+       return suffix;
+}
+
+static int mpi_rank = -1;
+
+static size_t mpi_provider_get_size(void *uctx,
+                                   struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
+                                   size_t offset)
+{
+       size_t size = 0;
+       char *type = context_type(uctx);
+
+       size += lttng_ust_ring_buffer_align(offset, lttng_ust_rb_alignof(char));
+       size += sizeof(char);
+
+       if (unlikely(!type)) {
+               goto error;
+       }
+
+       if (streq(type, "rank")) {
+               size += lttng_ust_ring_buffer_align(offset, lttng_ust_rb_alignof(int64_t));
+               size += sizeof(int64_t);
+
+       } else {
+       error:
+               /* Unknown context. */
+               (void) size;
+       }
+
+       return size;
+}
+
+static void mpi_provider_record(void *uctx,
+                               struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
+                               struct lttng_ust_ring_buffer_ctx *ctx,
+                               struct lttng_ust_channel_buffer *lttng_chan_buf)
+{
+       int sel;
+       char sel_char;
+       char *type = context_type(uctx);
+
+       if (unlikely(!type)) {
+               goto error;
+       }
+
+       if (streq(type, "rank")) {
+               int64_t v;
+               sel      = LTTNG_UST_DYNAMIC_TYPE_S64;
+               sel_char = (char) sel;
+               v        = (int64_t) mpi_rank;
+               lttng_chan_buf->ops->event_write(ctx, &sel_char, sizeof(sel_char),
+                                                lttng_ust_rb_alignof(char));
+
+               lttng_chan_buf->ops->event_write(ctx, &v, sizeof(v), lttng_ust_rb_alignof(v));
+       } else {
+       error:
+               sel      = LTTNG_UST_DYNAMIC_TYPE_NONE;
+               sel_char = (char) sel;
+               lttng_chan_buf->ops->event_write(ctx, &sel_char, sizeof(sel_char),
+                                                lttng_ust_rb_alignof(char));
+       }
+}
+
+static void mpi_provider_get_value(void *uctx,
+                                  struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
+                                  struct lttng_ust_ctx_value *value)
+{
+       char *type = context_type(uctx);
+
+       if (unlikely(!type)) {
+               goto error;
+       }
+
+       if (streq(type, "rank")) {
+               value->sel   = LTTNG_UST_DYNAMIC_TYPE_S64;
+               value->u.s64 = (int64_t) mpi_rank;
+       } else {
+       error:
+               value->sel = LTTNG_UST_DYNAMIC_TYPE_NONE;
+       }
+}
+
+static struct lttng_ust_context_provider mpi_provider = {
+       .struct_size = sizeof(struct lttng_ust_context_provider),
+       .name        = "$app.MPI",
+       .get_size    = mpi_provider_get_size,
+       .record      = mpi_provider_record,
+       .get_value   = mpi_provider_get_value
+};
+""")
+        for fn in list_functions(parse_header(args.api)):
+
+            if fn.spelling in forbiden_list:
+                continue
+
+            # Kept under its own name so that the command line arguments in
+            # `args' are not clobbered; the cursor is walked only once.
+            fn_args = list(fn.get_arguments())
+
+            # Names only, used to forward the call to the real function.
+            fn_pass_argument_names = ", ".join([
+                f"{arg.spelling}"
+                for arg in fn_args
+            ])
+
+            # Names preceded by a cast to their type (array brackets turned
+            # into `*'), appended after the function name in the
+            # LTTNG_MAKE_API_OBJECT() invocation.
+            if fn_args:
+                fn_rest_argument_names = ", " + ", ".join([
+                    "(%s)%s" % (re.sub(r'\[[0-9]*\]', '*', arg.type.spelling),
+                                arg.spelling)
+                    for arg in fn_args
+                ])
+            else:
+                fn_rest_argument_names=""
+
+            extra_work = extra_works.get(fn.spelling, "")
+
+            output.write(fn_tpl.substitute(fn_name=fn.spelling,
+                                           fn_arguments=", ".join([
+                                               exact_definition(arg)
+                                               for arg in fn_args
+                                           ]),
+                                           fn_pass_argument_names=fn_pass_argument_names,
+                                           fn_rest_argument_names=fn_rest_argument_names,
+                                           ret_type=fn.type.get_result().spelling,
+                                           extra_work=extra_work))
+
+if __name__ == "__main__":
+    main()
diff --git a/lttng-auto-ust-api b/lttng-auto-ust-api
new file mode 100755 (executable)
index 0000000..e93dfd9
--- /dev/null
@@ -0,0 +1,880 @@
+#!/usr/bin/env python3
+#
+# SPDX-License-Identifier: MIT
+#
+# Copyright (c) 2023 EfficiOS, Inc.
+#
+# Author: Olivier Dion <odion@efficios.com>
+#
+# Auto-generate lttng-ust tracepoints for OpenMPI.
+#
+# Require: python-clang (libclang)
+
+import argparse
+import re
+import os
+import subprocess
+
+from string import Template
+
+import clang.cindex
+
+COMMON_PREFIX = None
+IGNORE        = None
+PROVIDER      = None
+
+# LTTNG_UST_TP_ARGS is limited to 10 arguments.  Since we introduce two
+# arguments of our own (thread-id and local-id), the maximum is 8.
+#
+# If a function has more arguments than this limit, all arguments -- at the
+# exception of the IDs -- will be passed through a data structure instead.
+MAX_TP_ARGS_COUNT = 8
+
+class EnumValue:
+    """Name and numerical value of a single enumerator."""
+
+    def __init__(self, ev):
+        self.name, self.value = ev.spelling, ev.enum_value
+
+class EnumType:
+    """An enumeration: its name and the list of its enumerators.
+
+    NAME, when given and non-empty, replaces the spelling of the declaration.
+    """
+
+    def __init__(self, en, name=None):
+        self.name   = name if name else en.spelling
+        self.values = list(map(EnumValue, en.get_children()))
+
+class Typedef:
+    """Pair made of a typedef name (`spelling') and what it names (`value')."""
+
+    def __init__(self, spelling, value):
+        self.spelling, self.value = spelling, value
+
+class ArgumentType:
+    """A function argument (or record field) and its LTTng-UST field mapping.
+
+    NAME_PREFIX is prepended to the name of the generated field and
+    EXPR_PREFIX to the C expression that reads the value.
+    """
+
+    # Canonical type kinds recorded as integers.
+    integer_set = {
+        clang.cindex.TypeKind.UCHAR,
+        clang.cindex.TypeKind.USHORT,
+        clang.cindex.TypeKind.UINT,
+        clang.cindex.TypeKind.ULONG,
+        clang.cindex.TypeKind.ULONGLONG,
+        clang.cindex.TypeKind.SHORT,
+        clang.cindex.TypeKind.INT,
+        clang.cindex.TypeKind.LONG,
+        clang.cindex.TypeKind.LONGLONG,
+    }
+
+    # Canonical type kinds recorded as floating point numbers.
+    float_set = {
+        clang.cindex.TypeKind.FLOAT,
+        clang.cindex.TypeKind.DOUBLE,
+    }
+
+    # Canonical type kinds recorded as hexadecimal addresses.
+    address_set = {
+        clang.cindex.TypeKind.POINTER,
+        clang.cindex.TypeKind.INCOMPLETEARRAY,
+    }
+
+    def __init__(self, arg, name_prefix="", expr_prefix=""):
+        self.type = arg.type
+        self.arg  = arg
+        self.name_prefix = name_prefix
+        self.expr_prefix = expr_prefix
+
+        # For a pointer, constness is the one of the pointee; otherwise the
+        # one of the type itself.
+        if self.kind() == clang.cindex.TypeKind.POINTER:
+            is_const = self.type.get_pointee().is_const_qualified()
+        else:
+            is_const = self.type.is_const_qualified()
+        self.const = "const " if is_const else ""
+
+    def name(self):
+        """Return the name of the argument."""
+        return self.arg.spelling
+
+    def type_name(self):
+        """Return the C type used to pass the argument to the tracepoint."""
+        kind = self.kind()
+        if kind == clang.cindex.TypeKind.INCOMPLETEARRAY:
+            # Array brackets become pointer stars.
+            return self.const + re.sub(r"\[[0-9]*\]", "*", self.type.spelling)
+        if kind == clang.cindex.TypeKind.POINTER:
+            # Every pointer is passed as a (possibly const) `void *'.
+            return f"{self.const}void *"
+        return self.const + self.type.spelling
+
+    def kind(self):
+        """Return the kind of the canonical type of the argument."""
+        return self.type.get_canonical().kind
+
+    def to_lttng_field(self):
+        """Return the field definition of the argument.
+
+        The result is a string, except for a record where it is the list of
+        the results for every field of the record.  An argument named
+        `reserved' yields an empty string.  An exception is raised for a kind
+        that is not supported.
+        """
+        name = self.name()
+        if name == "reserved":
+            return ""
+
+        kind       = self.kind()
+        field_name = f"{self.name_prefix}{name}"
+        field_expr = f"{self.expr_prefix}{name}"
+
+        if kind in ArgumentType.address_set:
+            return f"lttng_ust_field_integer_hex(uintptr_t, {field_name}, (uintptr_t){field_expr})"
+
+        if kind in ArgumentType.integer_set:
+            return f"lttng_ust_field_integer({self.type_name()}, {field_name}, {field_expr})"
+
+        if kind in ArgumentType.float_set:
+            return f"lttng_ust_field_float({self.type_name()}, {field_name}, {field_expr})"
+
+        if kind == clang.cindex.TypeKind.ENUM:
+            enum_name = self.type_name().removeprefix("enum ")
+            return f"lttng_ust_field_enum({PROVIDER}, {enum_name}, int, {field_name}, {field_expr})"
+
+        if kind == clang.cindex.TypeKind.RECORD:
+            # One field per member, named `<argument>_<member>' and read
+            # through `<expression>.<member>'.
+            members = self.type.get_canonical().get_fields()
+            return [
+                ArgumentType(member, f"{name}_", f"{field_expr}.").to_lttng_field()
+                for member in members
+            ]
+
+        raise Exception("Unsupported kind: %s" % kind)
+
+class FunctionType:
+    """A function of the API and the pieces of generated code derived from it.
+
+    When the function has more than MAX_TP_ARGS_COUNT arguments, they are
+    passed to the tracepoint through a single pointer to a structure named
+    after the function instead of one by one.
+    """
+
+    struct_tpl = Template("""
+$name {
+    $fields
+};
+""")
+
+    def __init__(self, fn):
+        self.name = fn.spelling
+        self.args = list(map(ArgumentType, fn.get_arguments()))
+        self.fn   = fn
+
+    def _pass_by_struct(self):
+        """Tell whether the arguments are packed in a structure."""
+        return len(self.args) > MAX_TP_ARGS_COUNT
+
+    def tp_args(self):
+        """Return the `type, name' pairs to append to LTTNG_UST_TP_ARGS()."""
+        if not self.args:
+            return ""
+
+        separator = ",\n        "
+
+        if self._pass_by_struct():
+            return separator + f"{self.arguments_struct_name()} *, lttng_args"
+
+        pairs = [f"{arg.type_name()}, {arg.name()}" for arg in self.args]
+        return separator + separator.join(pairs)
+
+    def tp_fields(self):
+        """Return the field definitions to put in LTTNG_UST_TP_FIELDS()."""
+        if not self.args:
+            return ""
+
+        if self._pass_by_struct():
+            # Values are read through the `lttng_args' pointer.
+            sources = [ArgumentType(arg.arg, "", "lttng_args->") for arg in self.args]
+        else:
+            sources = self.args
+
+        fields = [source.to_lttng_field() for source in sources]
+        return "\n        ".join(flatten(fields))
+
+    def get_return_type_name(self):
+        """Return the spelling of the return type of the function."""
+        return self.fn.type.get_result().spelling
+
+    def ctor_params(self):
+        """Return the extra parameters passed to the enter tracepoint."""
+        if not self.args:
+            return ""
+        if self._pass_by_struct():
+            return ", &lttng_args"
+        names = [arg.name() for arg in self.args]
+        return ", " + ", ".join(names)
+
+    def arguments_struct_variable(self):
+        """Return the definition of the `lttng_args' variable.
+
+        When the arguments are not packed, a C comment is returned instead.
+        """
+        struct_name = self.arguments_struct_name()
+        if not self._pass_by_struct():
+            return f"/* {struct_name} lttng_args */"
+        initializers = ", ".join([arg.name() for arg in self.args])
+        return "%s lttng_args = {%s};" % (struct_name, initializers)
+
+
+    def arguments_struct_name(self):
+        """Return the name of the structure holding the packed arguments."""
+        return f"struct lttng_arguments_of_{self.name}"
+
+    def arguments_struct(self):
+        """Return the definition of the packed arguments structure, if any."""
+        if not self._pass_by_struct():
+            return ""
+        members = [f"{arg.type_name()} {arg.name()};" for arg in self.args]
+        return self.struct_tpl.substitute(name=self.arguments_struct_name(),
+                                          fields="\n    ".join(members))
+
+def flatten(lst):
+    """Return a new list with the elements of LST, nested lists expanded in place.
+
+    Only `list' instances are expanded, at any depth; any other element is
+    kept as is.
+    """
+    flat = []
+    for item in lst:
+        if isinstance(item, list):
+            flat.extend(flatten(item))
+        else:
+            flat.append(item)
+    return flat
+
+def list_function_declarations(root):
+    """Return the direct children of ROOT that are function declarations."""
+    declarations = []
+    for cursor in root.get_children():
+        if cursor.kind == clang.cindex.CursorKind.FUNCTION_DECL:
+            declarations.append(cursor)
+    return declarations
+
+def list_enum_declarations(root):
+    """Return the direct children of ROOT that are enumeration declarations."""
+    declarations = []
+    for cursor in root.get_children():
+        if cursor.kind == clang.cindex.CursorKind.ENUM_DECL:
+            declarations.append(cursor)
+    return declarations
+
+def list_typedef_enums(root):
+    """Return a Typedef for every typedef under ROOT that names an enumeration."""
+    enums = []
+    for child in root.get_children():
+        if child.kind != clang.cindex.CursorKind.TYPEDEF_DECL:
+            continue
+        declaration = child.underlying_typedef_type.get_declaration()
+        if declaration.kind != clang.cindex.CursorKind.ENUM_DECL:
+            continue
+        enums.append(Typedef(child.spelling, declaration))
+    return enums
+
+def search_header_in(name, paths):
+    """Return the path of the first file called NAME found under PATHS.
+
+    PATHS is a colon separated list of directories, each one being walked
+    recursively (symbolic links are followed).  Return None when no file
+    matches.
+    """
+    for top in paths.split(":"):
+        for dirpath, _dirnames, filenames in os.walk(top, followlinks=True):
+            if name in filenames:
+                return os.path.join(dirpath, name)
+    return None
+
+def search_c_header(name):
+    """Search the header NAME in the directories listed in C_INCLUDE_PATH.
+
+    Return the path of the header, or None when it is not found or when
+    C_INCLUDE_PATH is not set.
+    """
+    # An unset variable is handled like an empty search path rather than
+    # aborting the whole generation with a KeyError.
+    return search_header_in(name, os.environ.get("C_INCLUDE_PATH", ""))
+
+def search_cxx_header(name):
+    """Search the header NAME in the directories listed in CPLUS_INCLUDE_PATH.
+
+    Return the path of the header, or None when it is not found or when
+    CPLUS_INCLUDE_PATH is not set.
+    """
+    # An unset variable is handled like an empty search path rather than
+    # aborting the whole generation with a KeyError.
+    return search_header_in(name, os.environ.get("CPLUS_INCLUDE_PATH", ""))
+
+def get_system_include_paths():
+    """Return the system include directories of clang as `-isystem' options.
+
+    clang is run verbosely on an empty input and the directories it prints on
+    its standard error, between the `#include <...> search starts here:' and
+    `End of search list.' lines, are collected.  The result is a flat list
+    alternating `-isystem' and a directory.
+    """
+
+    clang_args = ["clang", "-v", "-c", "-xc", "/dev/null"]
+    paths = []
+
+    # The pipe must be opened in text mode: in binary mode the lines are
+    # `bytes' objects which never compare equal to the `str' markers below,
+    # and no directory would ever be collected.
+    with subprocess.Popen(clang_args, stderr=subprocess.PIPE, text=True) as proc:
+        start_sys_search = False
+        for line in proc.stderr:
+            if start_sys_search:
+                if line == "End of search list.\n":
+                    break
+                paths.append("-isystem")
+                paths.append(line.strip())
+            elif line == "#include <...> search starts here:\n":
+                start_sys_search = True
+
+    return paths
+
+def parse_header(header_file, includes, defines,
+                 required_c_headers, required_cxx_headers):
+    """Parse HEADER_FILE with libclang and return the root cursor of the result.
+
+    The compiler arguments are, in order: the system include paths, one `-I'
+    per entry of INCLUDES, one `-D' per entry of DEFINES, then one `-I' for
+    the directory of every required C and C++ header that could be located.
+    The diagnostics of the parser are printed.
+    """
+
+    args = get_system_include_paths()
+
+    for inc in includes or ():
+        args += ["-I", inc]
+
+    for d in defines or ():
+        args += ["-D", d]
+
+    # Required headers that cannot be located are silently skipped.
+    searches = (
+        (search_c_header, required_c_headers),
+        (search_cxx_header, required_cxx_headers),
+    )
+    for search, headers in searches:
+        for header in headers:
+            found = search(header)
+            if found:
+                args += ["-I", os.path.dirname(found)]
+
+    index = clang.cindex.Index.create()
+    tu = index.parse(header_file, args=args)
+
+    for diagnostic in tu.diagnostics:
+        print(diagnostic)
+
+    return tu.cursor
+
+def list_functions(root):
+    """Return a FunctionType for every selected function declared under ROOT.
+
+    A function is selected when its name starts with the module-level
+    COMMON_PREFIX and is not part of the module-level IGNORE.
+    """
+    functions = []
+    for fn in list_function_declarations(root):
+        if not fn.spelling.startswith(COMMON_PREFIX):
+            continue
+        if fn.spelling in IGNORE:
+            continue
+        functions.append(FunctionType(fn))
+    return functions
+
+def list_enums(root):
+    """Return an EnumType for every selected enumeration declared under ROOT.
+
+    Plain enumerations come first, followed by the enumerations reached
+    through a typedef, which take the name of the typedef.  A name is
+    selected when it starts with the module-level COMMON_PREFIX and is not
+    part of the module-level IGNORE.
+    """
+    def selected(spelling):
+        return spelling.startswith(COMMON_PREFIX) and spelling not in IGNORE
+
+    result = []
+
+    for en in list_enum_declarations(root):
+        if selected(en.spelling):
+            result.append(EnumType(en))
+
+    for typedef in list_typedef_enums(root):
+        if selected(typedef.spelling):
+            result.append(EnumType(typedef.value, typedef.spelling))
+
+    return result
+
+def generate_tracepoint_definitions(function_declarations, enum_declarations,
+                                    api_file, output_defs, output_interface,
+                                    header_guard):
+    """Write the LTTng-UST tracepoint definitions and their interface header.
+
+    For every function of FUNCTION_DECLARATIONS, an `enter_<name>' and an
+    `exit_<name>' event are defined; for every enumeration of
+    ENUM_DECLARATIONS, a tracepoint enumeration is defined.  The definitions
+    are written to OUTPUT_DEFS, which includes API_FILE and is protected by
+    HEADER_GUARD.  OUTPUT_INTERFACE receives a header that only includes
+    OUTPUT_DEFS.
+    """
+    # Skeleton of the definitions file.
+    defs_tpl = Template("""/* Auto-generated file! */
+#undef LTTNG_UST_TRACEPOINT_PROVIDER
+#define LTTNG_UST_TRACEPOINT_PROVIDER $provider
+
+#undef LTTNG_UST_TRACEPOINT_INCLUDE
+#define LTTNG_UST_TRACEPOINT_INCLUDE "$output_defs"
+
+#if !defined($header_guard)
+#include <$api_file>
+$pass_by_struct
+#endif
+
+#if !defined($header_guard) || defined(LTTNG_UST_TRACEPOINT_HEADER_MULTI_READ)
+#define $header_guard
+
+#include <lttng/tracepoint.h>
+
+$enum_definitions
+$tracepoint_definitions
+
+#endif /* $header_guard */
+
+#include <lttng/tracepoint-event.h>
+""")
+
+    # Skeleton of the interface file.
+    interface_tpl = Template("""/* Auto-generated file! */
+#ifndef ${header_guard}_IMPL
+#define ${header_guard}_IMPL
+
+#include "${output_defs}"
+
+#endif /* ${header_guard}_IMPL */
+""")
+
+    # Event emitted when entering a function: the two identifiers followed by
+    # the arguments of the function.
+    tp_tpl = Template("""
+LTTNG_UST_TRACEPOINT_EVENT(
+    $provider,
+    enter_$name,
+    LTTNG_UST_TP_ARGS(
+        uint64_t, lttng_thread_id,
+        uint64_t, lttng_local_id$tp_args
+    ),
+    LTTNG_UST_TP_FIELDS(
+        lttng_ust_field_integer(uint64_t, lttng_thread_id, lttng_thread_id)
+        lttng_ust_field_integer(uint64_t, lttng_local_id, lttng_local_id)
+        $tp_fields
+    )
+)
+""")
+
+    # Event emitted when leaving a function that returns a value.
+    tp_ret_tpl = Template("""
+LTTNG_UST_TRACEPOINT_EVENT(
+    $provider,
+    exit_$name,
+    LTTNG_UST_TP_ARGS(
+        uint64_t, lttng_thread_id,
+        uint64_t, lttng_local_id,
+        int, lttng_has_ret,
+        $ret_type, lttng_ret
+    ),
+    LTTNG_UST_TP_FIELDS(
+        lttng_ust_field_integer(uint64_t, lttng_thread_id, lttng_thread_id)
+        lttng_ust_field_integer(uint64_t, lttng_local_id, lttng_local_id)
+        lttng_ust_field_integer(int, lttng_has_ret, lttng_has_ret)
+        lttng_ust_field_integer($ret_type, lttng_ret, lttng_ret)
+    )
+)
+""")
+
+    # Event emitted when leaving a function that returns `void': same as
+    # above, without the `lttng_ret' field.
+    tp_void_tpl = Template("""
+LTTNG_UST_TRACEPOINT_EVENT(
+    $provider,
+    exit_$name,
+    LTTNG_UST_TP_ARGS(
+        uint64_t, lttng_thread_id,
+        uint64_t, lttng_local_id,
+        int, lttng_has_ret
+    ),
+    LTTNG_UST_TP_FIELDS(
+        lttng_ust_field_integer(uint64_t, lttng_thread_id, lttng_thread_id)
+        lttng_ust_field_integer(uint64_t, lttng_local_id, lttng_local_id)
+        lttng_ust_field_integer(int, lttng_has_ret, lttng_has_ret)
+    )
+)
+""")
+    # Definition of a tracepoint enumeration.
+    enum_tpl = Template("""
+LTTNG_UST_TRACEPOINT_ENUM($provider, $name,
+    LTTNG_UST_TP_ENUM_VALUES(
+        $values
+    )
+)
+""")
+    with open(output_defs, "w") as output:
+        definitions = []
+        for fn in function_declarations:
+            ret_type = fn.get_return_type_name()
+            # Enter event.
+            definitions.append(tp_tpl.substitute(provider=PROVIDER,
+                                                 name=fn.name,
+                                                 tp_args=fn.tp_args(),
+                                                 tp_fields=fn.tp_fields()))
+            # Exit event; the template depends on the return type.
+            if ret_type == "void":
+                tpl = tp_void_tpl
+            else:
+                tpl = tp_ret_tpl
+
+            # NOTE(review): `ret_type' is passed even to `tp_void_tpl', which
+            # has no `$ret_type' placeholder.
+            definitions.append(tpl.substitute(provider=PROVIDER,
+                                              name=fn.name,
+                                              ret_type=ret_type))
+
+        tracepoint_definitions = "\n".join(definitions)
+
+        enum_definitions = "\n".join([
+            enum_tpl.substitute(provider=PROVIDER,
+                                name=en.name,
+                                values="\n        ".join([f'lttng_ust_field_enum_value("{ev.name}", {ev.value})'
+                                                          for ev in en.values]))
+            for en in enum_declarations
+        ])
+
+        # `pass_by_struct' holds the structures used by the functions whose
+        # arguments are passed through a structure.
+        output.write(defs_tpl.substitute(provider=PROVIDER,
+                                         output_defs=output_defs,
+                                         header_guard=header_guard,
+                                         tracepoint_definitions=tracepoint_definitions,
+                                         enum_definitions=enum_definitions,
+                                         api_file=api_file,
+                                         pass_by_struct="".join([fn.arguments_struct()
+                                                                 for fn in function_declarations])))
+    with open(output_interface, "w") as output:
+        output.write(interface_tpl.substitute(header_guard=header_guard,
+                                              output_defs=output_defs,))
+
+def generate_tracepoint_classes(function_declarations, api_file, output_path, header_guard, namespace):
+    """Write the C++ classes that emit the enter and exit events.
+
+    For every function of FUNCTION_DECLARATIONS, a class `api_object_<name>'
+    is generated in NAMESPACE: its constructor emits the `enter_<name>' event
+    and its destructor the `exit_<name>' event.  The result, which includes
+    API_FILE, is written to OUTPUT_PATH.  HEADER_GUARD is not used.
+    """
+    # The members of the base classes are value-initialized: the destructor
+    # of an API object reads them even when the enter tracepoint was disabled
+    # (no identifier generated) or when mark_return() was never called.
+    global_tpl = Template("""/* Auto-generated file! */
+#include <atomic>
+#include <cstdint>
+#include <$api_file>
+namespace $namespace {
+    struct unique_id {
+        uint64_t thread_id;
+        uint64_t local_id;
+    };
+
+    class id_generator {
+        static std::atomic<uint64_t> _thread_counter;
+        uint64_t _thread_id;
+        uint64_t _local_id;
+    public:
+        id_generator() {
+            _thread_id = _thread_counter++;
+            _local_id  = 0;
+        }
+
+        unique_id next_id() {
+            return {
+                .thread_id = _thread_id,
+                .local_id  = _local_id++,
+            };
+        }
+    };
+
+    extern thread_local id_generator generator;
+
+    template<typename RetType>
+    class base_api_object {
+    protected:
+        unique_id  _id {};
+        int        _has_ret = 0;
+        RetType    _ret {};
+    public:
+        void generate_id() {
+            _id = generator.next_id();
+        }
+
+        void mark_return(RetType ret) {
+            _ret     = ret;
+            _has_ret = 1;
+        }
+    };
+
+    class base_api_object_void {
+    protected:
+        unique_id  _id {};
+        int        _has_ret = 0;
+    public:
+        void generate_id() {
+            _id = generator.next_id();
+        }
+
+        void mark_return(void) {
+            _has_ret = 1;
+        }
+    };
+
+$classes
+};
+""")
+
+    # Class generated for a function that returns a value.
+    cls_ret_tpl = Template("""
+class api_object_$fn_name : public base_api_object<$ret_type>
+{
+public:
+    api_object_$fn_name($ctor_type_params) {
+        if (lttng_ust_tracepoint_enabled($provider, enter_$fn_name)) {
+            generate_id();
+            $pass_by_struct
+            lttng_ust_do_tracepoint($provider,
+                                    enter_$fn_name,
+                                    _id.thread_id,
+                                    _id.local_id$ctor_params);
+        }
+    }
+    ~api_object_$fn_name() {
+        if (lttng_ust_tracepoint_enabled($provider, exit_$fn_name)) {
+            lttng_ust_do_tracepoint($provider,
+                                    exit_$fn_name,
+                                    _id.thread_id,
+                                    _id.local_id,
+                                    _has_ret,
+                                    _ret);
+        }
+    }
+};
+""")
+
+    # Class generated for a function that returns `void'.
+    cls_void_tpl = Template("""
+class api_object_$fn_name : public base_api_object_void
+{
+public:
+    api_object_$fn_name($ctor_type_params) {
+        if (lttng_ust_tracepoint_enabled($provider, enter_$fn_name)) {
+            generate_id();
+            $pass_by_struct
+            lttng_ust_do_tracepoint($provider,
+                                    enter_$fn_name,
+                                    _id.thread_id,
+                                    _id.local_id$ctor_params);
+        }
+    }
+    ~api_object_$fn_name() {
+        if (lttng_ust_tracepoint_enabled($provider, exit_$fn_name)) {
+            lttng_ust_do_tracepoint($provider,
+                                    exit_$fn_name,
+                                    _id.thread_id,
+                                    _id.local_id,
+                                    _has_ret);
+        }
+    }
+};
+""")
+
+    with open(output_path, "w") as output:
+        classes = []
+        for fn in function_declarations:
+            ret_type = fn.get_return_type_name()
+            if ret_type == "void":
+                cls_tpl = cls_void_tpl
+            else:
+                cls_tpl = cls_ret_tpl
+            classes.append(cls_tpl.substitute(provider=PROVIDER,
+                                              fn_name=fn.name,
+                                              pass_by_struct=fn.arguments_struct_variable(),
+                                              ctor_type_params=", ".join([f"{arg.type_name()} {arg.name()}"
+                                                                          for arg in fn.args]),
+                                              ctor_params=fn.ctor_params(),
+                                              ret_type=ret_type))
+        output.write(global_tpl.substitute(api_file=api_file,
+                                           namespace=namespace,
+                                           classes="".join(classes)))
+
+def generate_tracepoint_emulated_classes(function_declarations, api_file, output_path,
+                                         header_guard, namespace):
+    global_tpl = Template("""/* Auto-generated file! */
+#include <stdint.h>
+#include <$api_file>
+#define ${NAMESPACE}_CAT_PRIMITIVE(A, B) A##B
+#define ${NAMESPACE}_CAT(A, B) ${NAMESPACE}_CAT_PRIMITIVE(A, B)
+
+struct ${namespace}_unique_id {
+       uint64_t thread_id;
+       uint64_t local_id;
+};
+
+struct ${namespace}_id_generator {
+       uint64_t thread_id;
+       uint64_t local_id;
+       int initialized;
+};
+
+extern uint64_t ${namespace}_id_generator_thread_counter;
+extern _Thread_local struct ${namespace}_id_generator ${namespace}_generator;
+
+#define ${namespace}_unlikely(x) __builtin_expect(!!(x), 0)
+
+static inline void ${namespace}_id_generator_next_id(struct ${namespace}_unique_id *id)
+{
+       if (${namespace}_unlikely(!${namespace}_generator.initialized)) {
+               ${namespace}_generator.thread_id =
+                       __atomic_fetch_add(&${namespace}_id_generator_thread_counter,
+                                          1,
+                                          __ATOMIC_RELAXED);
+               ${namespace}_generator.initialized = 1;
+       }
+
+       id->thread_id = ${namespace}_generator.thread_id;
+       id->local_id = ${namespace}_generator.local_id++;
+}
+
+#define ${NAMESPACE}_API_OBJECT_NAME ${namespace}_api_object
+
+#define ${NAMESPACE}_MAKE_API_OBJECT(name, ...) \\
+       struct ${NAMESPACE}_CAT(${namespace}_api_state_, name) __attribute__((cleanup(${NAMESPACE}_CAT(exit_, name)))) \\
+       ${NAMESPACE}_API_OBJECT_NAME = { 0 };  \\
+       ${NAMESPACE}_CAT(enter_, name)(&${NAMESPACE}_API_OBJECT_NAME, ##__VA_ARGS__); \\
+       do { } while (0)
+
+#define ${NAMESPACE}_MARK_RETURN_API_OBJECT(code) \\
+       ({                                                   \\
+               ${NAMESPACE}_API_OBJECT_NAME.ret = code;     \\
+               ${NAMESPACE}_API_OBJECT_NAME.has_ret = 1;    \\
+       })
+${classes}
+""")
+
+    cls_tpl = Template("""
+struct ${namespace}_api_state_${fn_name} {
+       struct ${namespace}_unique_id id;
+       int has_ret;
+       $ret_type ret;
+};
+
+static inline void enter_${fn_name}(${ctor_type_params})
+{
+       if (${namespace}_ust_tracepoint_enabled(${provider}, enter_${fn_name})) {
+               ${namespace}_id_generator_next_id(&lttng_state->id);
+               ${pass_by_struct}
+               ${namespace}_ust_do_tracepoint($provider, enter_${fn_name},
+                                              lttng_state->id.thread_id,
+                                              lttng_state->id.local_id${ctor_params});
+       }
+}
+
+static inline void exit_${fn_name}(const struct ${namespace}_api_state_${fn_name} *lttng_state)
+{
+       lttng_ust_tracepoint(${provider}, exit_${fn_name},
+                            lttng_state->id.thread_id,
+                            lttng_state->id.local_id,
+                            lttng_state->has_ret,
+                            lttng_state->ret);
+}
+""")
+    with open(output_path, "w") as output:
+        output.write(global_tpl.substitute(api_file=api_file,
+                                           namespace=namespace,
+                                           NAMESPACE=namespace.upper(),
+                                           classes="".join([
+                                               cls_tpl.substitute(provider=PROVIDER,
+                                                                  fn_name=fn.name,
+                                                                  pass_by_struct=fn.arguments_struct_variable(),
+                                                                  ctor_params=fn.ctor_params(),
+                                                                  ctor_type_params=", ".join([f"struct {namespace}_api_state_{fn.name} *lttng_state"] +
+                                                                                             [f"{arg.type_name()} {arg.name()}"
+                                                                                              for arg in fn.args]),
+                                                                  namespace=namespace,
+                                                                  NAMESPACE=namespace.upper(),
+                                                                  ret_type=fn.get_return_type_name())
+                                               for fn in function_declarations
+                                           ])))
+
+
+def generate_tracepoint_implementations(guard, namespace, defs, impls):
+    tpl = Template("""/* Auto-generated !*/
+#ifdef ${guard}
+
+#define LTTNG_UST_TRACEPOINT_CREATE_PROBES
+#define LTTNG_UST_TRACEPOINT_DEFINE
+#include "${defs}"
+
+#endif /* ${guard} */
+""")
+
+    with open(impls, "w") as output:
+        output.write(tpl.substitute(guard=guard,
+                                    defs=defs))
+
+def generate_tracepoint_states(states_guard,
+                               namespace,
+                               interface,
+                               classes,
+                               states,
+                               emulated_classes):
+
+    if emulated_classes:
+        body_tpl = Template("""
+uint64_t ${namespace}_id_generator_thread_counter = 0;
+_Thread_local struct ${namespace}_id_generator ${namespace}_generator;
+""")
+    else:
+        body_tpl = Template("""
+#include <atomic>
+namespace ${namespace} {
+       std::atomic<uint64_t> id_generator::_thread_counter{0};
+       thread_local id_generator generator;
+};
+""")
+
+    tpl = Template("""/* Auto-generated! */
+#ifdef ${states_guard}
+#include "${interface}"
+#include "${classes}"
+
+$body
+#endif
+""")
+
+    with open(states, "w") as output:
+        output.write(tpl.substitute(states_guard=states_guard,
+                                    interface=interface,
+                                    classes=classes,
+                                    body=body_tpl.substitute(namespace=namespace)))
+
+def main():
+
+    global COMMON_PREFIX
+    global IGNORE
+    global PROVIDER
+
+    parser = argparse.ArgumentParser(prog="lttng-ust-autogen-api",
+                                     description="Generate LTTng classes and tracepoint definitions")
+
+    parser.add_argument("api",
+                        help="Header file that has the API")
+
+    parser.add_argument("defs",
+                        help="Path to tracepoint definitions")
+
+    parser.add_argument("interface",
+                        help="Path to tracepoints interfaces")
+
+    parser.add_argument("classes",
+                        help="Path to tracepoint classes")
+
+    parser.add_argument("impl",
+                        help="Path to tracepoint implementations")
+
+    parser.add_argument("states",
+                        help="Path to states")
+
+    parser.add_argument("--provider",
+                        dest="provider",
+                        metavar="PROVIDER",
+                        default="noprovider",
+                        help="Tracepoints PROVIDER")
+
+    parser.add_argument("--common-prefix",
+                        dest="common_prefix",
+                        metavar="PREFIX",
+                        default="",
+                        help="Common PREFIX of API functions (C namespace)")
+
+    parser.add_argument("-I",
+                        action="append",
+                        metavar="DIR",
+                        dest="includes",
+                        help="Add DIR to list of directories to include")
+
+    parser.add_argument("-D",
+                        action="append",
+                        metavar="DEFINITION",
+                        dest="defines",
+                        help="Add DEFINITION to list of definitions")
+
+    parser.add_argument("--tp-guard",
+                        dest="tp_guard",
+                        metavar="GUARD",
+                        default="LTTNG_TRACEPOINT_DEF_H",
+                        help="Use GUARD as header guard for tracepoint definitions")
+
+    parser.add_argument("--classes-guard",
+                        dest="classes_guard",
+                        metavar="GUARD",
+                        default="LTTNG_TRACEPOINT_CLASSES_HPP",
+                        help="Use GUARD as header guard for classes definitions")
+
+    parser.add_argument("--impl-guard",
+                        dest="impl_guard",
+                        metavar="GUARD",
+                        default="ENABLE_LTTNG_TRACEPOINTS",
+                        help="Use GUARD around implementations")
+
+    parser.add_argument("--states-guard",
+                        dest="states_guard",
+                        metavar="GUARD",
+                        default="ENABLE_LTTNG_TRACEPOINTS",
+                        help="Use GUARD around states")
+
+    parser.add_argument("--emulated-classes",
+                        dest="emulated_classes",
+                        action="store_true",
+                        default=False,
+                        help="Emulate C++ classes")
+
+    parser.add_argument("--namespace",
+                        dest="namespace",
+                        metavar="NAMESPACE",
+                        default="lttng",
+                        help="Generate classes in NAMESPACE")
+
+    parser.add_argument("--ignore",
+                        dest="ignore",
+                        metavar="FUNCTION",
+                        action="append",
+                        default=[],
+                        help="Ignore FUNCTION")
+
+    parser.add_argument("--c-header",
+                        dest="required_c_headers",
+                        metavar="HEADER",
+                        action="append",
+                        default=[],
+                        help="Search for HEADER in C_INCLUDE_PATH and add its directory to search path")
+
+    parser.add_argument("--cxx-header",
+                        dest="required_cxx_headers",
+                        metavar="HEADER",
+                        action="append",
+                        default=[],
+                        help="Search for HEADER in CPLUS_INCLUDE_PATH add its directory to search path")
+
+    args = parser.parse_args()
+
+    PROVIDER      = args.provider
+    COMMON_PREFIX = args.common_prefix
+    IGNORE        = set(args.ignore)
+
+    root = parse_header(args.api, args.includes, args.defines,
+                        args.required_c_headers,
+                        args.required_cxx_headers)
+
+    function_declarations = list_functions(root)
+    enum_declarations     = list_enums(root)
+
+    generate_tracepoint_definitions(function_declarations,
+                                    enum_declarations,
+                                    args.api, args.defs, args.interface,
+                                    args.tp_guard)
+
+    if args.emulated_classes:
+        generate_tracepoint_emulated_classes(function_declarations,
+                                             args.api,
+                                             args.classes,
+                                             args.classes_guard,
+                                             args.namespace)
+    else:
+        generate_tracepoint_classes(function_declarations,
+                                    args.api,
+                                    args.classes,
+                                    args.classes_guard,
+                                    args.namespace)
+
+    generate_tracepoint_implementations(args.impl_guard,
+                                        args.namespace,
+                                        args.interface,
+                                        args.impl)
+
+    generate_tracepoint_states(args.states_guard,
+                               args.namespace,
+                               args.interface,
+                               args.classes,
+                               args.states,
+                               args.emulated_classes)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/lttng/ust-context-provider.h b/lttng/ust-context-provider.h
new file mode 100644 (file)
index 0000000..6857c00
--- /dev/null
@@ -0,0 +1,124 @@
+/*
+ * SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2016 Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
+ *
+ * The context provider feature is part of the ABI and used by the Java jni
+ * interface. This header should be moved to the public header directory once
+ * some test code and documentation is written.
+ */
+
+#ifndef _LTTNG_UST_CONTEXT_PROVIDER_H
+#define _LTTNG_UST_CONTEXT_PROVIDER_H
+
+#include <stddef.h>
+#include <lttng/ust-events.h>
+
+/*
+ * Selector naming the type of a dynamically-typed value.
+ *
+ * Enumerators are implicitly numbered from 0, so the trailing
+ * _NR_LTTNG_UST_DYNAMIC_TYPES equals the number of selectors listed
+ * before it.
+ */
+enum lttng_ust_dynamic_type {
+       LTTNG_UST_DYNAMIC_TYPE_NONE,
+       LTTNG_UST_DYNAMIC_TYPE_S8,
+       LTTNG_UST_DYNAMIC_TYPE_S16,
+       LTTNG_UST_DYNAMIC_TYPE_S32,
+       LTTNG_UST_DYNAMIC_TYPE_S64,
+       LTTNG_UST_DYNAMIC_TYPE_U8,
+       LTTNG_UST_DYNAMIC_TYPE_U16,
+       LTTNG_UST_DYNAMIC_TYPE_U32,
+       LTTNG_UST_DYNAMIC_TYPE_U64,
+       LTTNG_UST_DYNAMIC_TYPE_FLOAT,
+       LTTNG_UST_DYNAMIC_TYPE_DOUBLE,
+       LTTNG_UST_DYNAMIC_TYPE_STRING,
+       _NR_LTTNG_UST_DYNAMIC_TYPES,
+};
+
+int lttng_ust_dynamic_type_choices(size_t *nr_choices,
+               const struct lttng_ust_event_field * const **choices)
+       __attribute__((visibility("hidden")));
+
+const struct lttng_ust_event_field *lttng_ust_dynamic_type_field(int64_t value)
+       __attribute__((visibility("hidden")));
+
+const struct lttng_ust_event_field *lttng_ust_dynamic_type_tag_field(void)
+       __attribute__((visibility("hidden")));
+
+struct lttng_ust_registered_context_provider;
+struct lttng_ust_probe_ctx;
+
+/*
+ * Context value
+ *
+ * IMPORTANT: this structure is part of the ABI between the probe and
+ * UST. Additional selectors may be added in the future, mapping to new
+ * union fields, which means the overall size of this structure may
+ * increase. This means this structure should never be nested within a
+ * public structure interface, nor embedded in an array.
+ */
+
+/*
+ * Tagged value: a type selector plus a union holding the payload.
+ * NOTE(review): presumably `sel` tells which member of `u` is meaningful;
+ * the selector-to-member mapping is not defined in this header.
+ */
+struct lttng_ust_ctx_value {
+       enum lttng_ust_dynamic_type sel;        /* Type selector */
+       union {
+               int64_t s64;            /* signed 64-bit integer payload */
+               uint64_t u64;           /* unsigned 64-bit integer payload */
+               const char *str;        /* pointer to a constant character string */
+               double d;               /* double-precision floating-point payload */
+       } u;
+};
+
+/*
+ * Context provider
+ *
+ * IMPORTANT: this structure is part of the ABI between the probe and
+ * UST. Fields need to be only added at the end, never reordered, never
+ * removed.
+ *
+ * The field @struct_size should be used to determine the size of the
+ * structure. It should be queried before using additional fields added
+ * at the end of the structure.
+ */
+
+/*
+ * Descriptor of a context provider: a name, three callbacks and an opaque
+ * pointer.  Every callback takes a `void *priv` as its first argument.
+ * NOTE(review): presumably the `priv` member below is what gets passed to
+ * the callbacks -- confirm against the UST implementation.
+ */
+struct lttng_ust_context_provider {
+       uint32_t struct_size;   /* size of this structure, see the ABI note above */
+
+       const char *name;       /* provider name */
+       /* Size query callback; receives an offset and returns a size_t. */
+       size_t (*get_size)(void *priv, struct lttng_ust_probe_ctx *probe_ctx,
+                       size_t offset);
+       /* Recording callback; receives a ring buffer context and a channel buffer. */
+       void (*record)(void *priv, struct lttng_ust_probe_ctx *probe_ctx,
+                       struct lttng_ust_ring_buffer_ctx *ctx,
+                       struct lttng_ust_channel_buffer *chan);
+       /* Value query callback; the result is written through `value`. */
+       void (*get_value)(void *priv, struct lttng_ust_probe_ctx *probe_ctx,
+                       struct lttng_ust_ctx_value *value);
+       void *priv;             /* opaque provider-private pointer */
+
+       /* End of base ABI. Fields below should be used after checking struct_size. */
+};
+
+/*
+ * Application context callback private data
+ *
+ * IMPORTANT: this structure is part of the ABI between the probe and
+ * UST. Fields need to be only added at the end, never reordered, never
+ * removed.
+ *
+ * The field @struct_size should be used to determine the size of the
+ * structure. It should be queried before using additional fields added
+ * at the end of the structure.
+ */
+
+/*
+ * Private data of an application context callback: an event field
+ * description and a context name.
+ */
+struct lttng_ust_app_context {
+       uint32_t struct_size;   /* size of this structure, see the ABI note above */
+
+       struct lttng_ust_event_field *event_field;      /* associated event field */
+       char *ctx_name;         /* context name (ownership not specified here) */
+
+       /* End of base ABI. Fields below should be used after checking struct_size. */
+};
+
+/*
+ * Returns an opaque pointer on success, which must be passed to
+ * lttng_ust_context_provider_unregister for unregistration. Returns
+ * NULL on error.
+ */
+struct lttng_ust_registered_context_provider *lttng_ust_context_provider_register(struct lttng_ust_context_provider *provider);
+
+void lttng_ust_context_provider_unregister(struct lttng_ust_registered_context_provider *reg_provider);
+
+#endif /* _LTTNG_UST_CONTEXT_PROVIDER_H */
diff --git a/run-test-mpi b/run-test-mpi
new file mode 100755 (executable)
index 0000000..6b0eca9
--- /dev/null
@@ -0,0 +1,2 @@
+#!/bin/sh
+env LD_PRELOAD=./liblttng-ust-mpi.so ./test-mpi $1
diff --git a/test-mpi.c b/test-mpi.c
new file mode 100644 (file)
index 0000000..ed1c9d1
--- /dev/null
@@ -0,0 +1,143 @@
+/*
+ * SPDX-License-Identifier: MIT
+ *
+ * Copyright (c) 2023 Olivier Dion <odion@efficios.com>
+ */
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <mpi.h>
+
+/*
+ * Return the sum of the first `values_count` elements of `values`
+ * (0 when `values_count` is 0).
+ *
+ * The accumulator has the same type as the elements and the return value,
+ * so the result wraps modulo 2^64 regardless of the width of size_t.
+ */
+static uint64_t sum_of(uint64_t *values, size_t values_count)
+{
+       uint64_t acc = 0;
+       for (size_t k=0; k<values_count; ++k) {
+               acc += values[k];
+       }
+       return acc;
+}
+
+/*
+ * Print the command-line synopsis on stderr and terminate the process
+ * with EXIT_FAILURE.  Does not return.
+ */
+static void usage(void)
+{
+       fprintf(stderr, "Usage: test-mpi N\n");
+       exit(EXIT_FAILURE);
+}
+
+/*
+ * Allocate an array of `upto` elements holding 1, 2, ..., upto.
+ *
+ * The caller owns the returned buffer and releases it with free().
+ * Terminates the process with EXIT_FAILURE if the allocation fails.
+ */
+static uint64_t *allocate_values(size_t upto)
+{
+       uint64_t *values = (uint64_t*)malloc(sizeof(uint64_t) * upto);
+
+       /* malloc(0) may return NULL without it being an error. */
+       if (values == NULL && upto != 0) {
+               perror("malloc");
+               exit(EXIT_FAILURE);
+       }
+
+       for (size_t k=0; k<upto; ++k) {
+               values[k] = k + 1;
+       }
+       return values;
+}
+
+/*
+ * Start a non-blocking send (MPI_Isend) of `values_count` MPI_UINT64_T
+ * elements from `values` to rank `target`, with tag 0 on MPI_COMM_WORLD.
+ * The request handle is stored in `*request`; the return code of
+ * MPI_Isend is not checked.
+ */
+static void send_values(int target, uint64_t *values,
+                        size_t values_count,
+                        MPI_Request *request)
+{
+       MPI_Isend(values, values_count, MPI_UINT64_T,
+                 target, 0, MPI_COMM_WORLD, request);
+}
+
+/*
+ * Start a non-blocking receive (MPI_Irecv) of a single MPI_UINT64_T from
+ * rank `target`, with tag 0 on MPI_COMM_WORLD, into `*value`.
+ * The request handle is stored in `*request`; the return code of
+ * MPI_Irecv is not checked.
+ */
+static void recv_answer(int target, uint64_t *value,
+                        MPI_Request *request)
+{
+       MPI_Irecv(value, 1, MPI_UINT64_T,
+                 target, 0, MPI_COMM_WORLD, request);
+}
+
+/*
+ * Send `value` as a single MPI_UINT64_T to rank 0 with a blocking
+ * MPI_Send (tag 0, MPI_COMM_WORLD).  The return code is not checked.
+ */
+static void send_answer(uint64_t value)
+{
+       MPI_Send(&value, 1, MPI_UINT64_T,
+                0, 0, MPI_COMM_WORLD);
+}
+
+/*
+ * Allocate a buffer of `chunk_size` elements and fill it with a blocking
+ * MPI_Recv of `chunk_size` MPI_UINT64_T from rank 0 (tag 0,
+ * MPI_COMM_WORLD).
+ *
+ * The caller owns the returned buffer and releases it with free().
+ * Terminates the process with EXIT_FAILURE if the allocation fails.
+ */
+static uint64_t *recv_values(size_t chunk_size)
+{
+       uint64_t *values = (uint64_t*)malloc(sizeof(uint64_t) * chunk_size);
+
+       /* A zero-sized chunk is valid and malloc(0) may return NULL for it. */
+       if (values == NULL && chunk_size != 0) {
+               perror("malloc");
+               exit(EXIT_FAILURE);
+       }
+
+       MPI_Recv(values, chunk_size, MPI_UINT64_T,
+                0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+       return values;
+}
+
+/*
+ * Distributed computation of 1 + 2 + ... + N, where N is argv[1].
+ *
+ * Rank 0 builds the array 1..N, hands one chunk of N / (size - 1) elements
+ * to every other rank with non-blocking sends, sums the remaining
+ * N % (size - 1) elements itself, then collects one partial sum per rank
+ * with non-blocking receives.  Every other rank receives its chunk, sums
+ * it and sends the result back to rank 0.  With a single rank, rank 0 sums
+ * everything locally.
+ *
+ * After MPI_Finalize, rank 0 asserts that the total equals N * (N + 1) / 2.
+ * Returns 0; exits with EXIT_FAILURE on a missing or non-positive N.
+ */
+int main(int argc, char *argv[])
+{
+       int rank;
+       int size;
+       long long upto;
+       uint64_t *values;
+
+       if (argc < 2) {
+               usage();
+       }
+
+       upto = atoll(argv[1]);
+
+       if (upto <= 0) {
+               fprintf(stderr, "N must be greater than 0\n");
+               exit(EXIT_FAILURE);
+       }
+
+       MPI_Init(&argc, &argv);
+
+       MPI_Comm_set_errhandler(MPI_COMM_WORLD,
+                               MPI_ERRORS_RETURN);
+
+       MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+       MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+       size_t chunk_size;
+       size_t rest;
+       uint64_t total = 0;
+
+       if (size > 1) {
+               chunk_size = upto / (size - 1);
+               rest = upto % (size - 1);
+       } else {
+               chunk_size = 0;
+               rest = upto;
+       }
+
+       if (rank == 0) {
+               int workers = size - 1;
+               uint64_t sums[size];
+               /* A VLA needs at least one element, even without any worker. */
+               MPI_Request requests[workers > 0 ? workers : 1];
+
+               values = allocate_values(upto);
+
+               for (int k=1; k<size; ++k) {
+                       send_values(k,
+                                   values + (chunk_size * (k - 1)),
+                                   chunk_size,
+                                   &requests[k-1]);
+               }
+
+               /* Rank 0 handles what is left after the even split. */
+               sums[0] = sum_of(values + chunk_size * (size - 1),
+                                rest);
+
+               /* An array of requests takes MPI_STATUSES_IGNORE (plural). */
+               MPI_Waitall(workers, requests, MPI_STATUSES_IGNORE);
+
+               for (int k=1; k<size; ++k) {
+                       recv_answer(k, &sums[k], &requests[k-1]);
+               }
+
+               MPI_Waitall(workers, requests, MPI_STATUSES_IGNORE);
+
+               total = sum_of(sums, size);
+
+               /* All sends completed at the first MPI_Waitall. */
+               free(values);
+       } else {
+               values = recv_values(chunk_size);
+               send_answer(sum_of(values, chunk_size));
+               free(values);
+       }
+
+       MPI_Finalize();
+
+       if (rank == 0){
+               assert(total ==
+                      (((uint64_t)upto * ((uint64_t)upto + 1U)) >> 1U));
+       }
+
+       return 0;
+}
diff --git a/test.c b/test.c
new file mode 100644 (file)
index 0000000..c44d13f
--- /dev/null
+++ b/test.c
@@ -0,0 +1,144 @@
+/*
+ * SPDX-License-Identifier: MIT
+ *
+ * Copyright (c) 2023 Olivier Dion <odion@efficios.com>
+ */
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <mpi.h>
+
+/*
+ * Return the sum of the first `values_count` elements of `values`
+ * (0 when `values_count` is 0).
+ *
+ * The accumulator has the same type as the elements and the return value,
+ * so the result wraps modulo 2^64 regardless of the width of size_t.
+ */
+static uint64_t sum_of(uint64_t *values, size_t values_count)
+{
+    uint64_t acc = 0;
+    for (size_t k=0; k<values_count; ++k) {
+        acc += values[k];
+    }
+    return acc;
+}
+
+/*
+ * Print the command-line synopsis on stderr and terminate the process
+ * with EXIT_FAILURE.  Does not return.
+ */
+static void usage(void)
+{
+    fprintf(stderr, "Usage: test-mpi N\n");
+    exit(EXIT_FAILURE);
+}
+
+/*
+ * Allocate an array of `upto` elements holding 1, 2, ..., upto.
+ *
+ * The caller owns the returned buffer and releases it with free().
+ * Terminates the process with EXIT_FAILURE if the allocation fails.
+ */
+static uint64_t *allocate_values(size_t upto)
+{
+    uint64_t *values = (uint64_t*)malloc(sizeof(uint64_t) * upto);
+
+    /* malloc(0) may return NULL without it being an error. */
+    if (values == NULL && upto != 0) {
+        perror("malloc");
+        exit(EXIT_FAILURE);
+    }
+
+    for (size_t k=0; k<upto; ++k) {
+        values[k] = k + 1;
+    }
+    return values;
+}
+
+/*
+ * Start a non-blocking send (MPI_Isend) of `values_count` MPI_UINT64_T
+ * elements from `values` to rank `target`, with tag 0 on MPI_COMM_WORLD.
+ * The request handle is stored in `*request`; the return code of
+ * MPI_Isend is not checked.
+ */
+static void send_values(int target, uint64_t *values,
+                        size_t values_count,
+                        MPI_Request *request)
+{
+    MPI_Isend(values, values_count, MPI_UINT64_T,
+              target, 0, MPI_COMM_WORLD, request);
+}
+
+/*
+ * Start a non-blocking receive (MPI_Irecv) of a single MPI_UINT64_T from
+ * rank `target`, with tag 0 on MPI_COMM_WORLD, into `*value`.
+ * The request handle is stored in `*request`; the return code of
+ * MPI_Irecv is not checked.
+ */
+static void recv_answer(int target, uint64_t *value,
+                        MPI_Request *request)
+{
+    MPI_Irecv(value, 1, MPI_UINT64_T,
+              target, 0, MPI_COMM_WORLD, request);
+}
+
+/*
+ * Send `value` as a single MPI_UINT64_T to rank 0 with a blocking
+ * MPI_Send (tag 0, MPI_COMM_WORLD).  The return code is not checked.
+ */
+static void send_answer(uint64_t value)
+{
+    MPI_Send(&value, 1, MPI_UINT64_T,
+             0, 0, MPI_COMM_WORLD);
+}
+
+/*
+ * Allocate a buffer of `chunk_size` elements and fill it with a blocking
+ * MPI_Recv of `chunk_size` MPI_UINT64_T from rank 0 (tag 0,
+ * MPI_COMM_WORLD).
+ *
+ * The caller owns the returned buffer and releases it with free().
+ * Terminates the process with EXIT_FAILURE if the allocation fails.
+ */
+static uint64_t *recv_values(size_t chunk_size)
+{
+    uint64_t *values = (uint64_t*)malloc(sizeof(uint64_t) * chunk_size);
+
+    /* A zero-sized chunk is valid and malloc(0) may return NULL for it. */
+    if (values == NULL && chunk_size != 0) {
+        perror("malloc");
+        exit(EXIT_FAILURE);
+    }
+
+    MPI_Recv(values, chunk_size, MPI_UINT64_T,
+             0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+    return values;
+}
+
+/*
+ * Distributed computation of 1 + 2 + ... + N, where N is argv[1].
+ *
+ * Rank 0 builds the array 1..N, hands one chunk of N / (size - 1) elements
+ * to every other rank with non-blocking sends, sums the remaining
+ * N % (size - 1) elements itself, then collects one partial sum per rank
+ * with non-blocking receives.  Every other rank receives its chunk, sums
+ * it and sends the result back to rank 0.  With a single rank, rank 0 sums
+ * everything locally.
+ *
+ * After MPI_Finalize, rank 0 asserts that the total equals N * (N + 1) / 2.
+ * Returns 0; exits with EXIT_FAILURE on a missing or non-positive N.
+ */
+int main(int argc, char *argv[])
+{
+    int rank;
+    int size;
+    long long upto;
+    uint64_t *values;
+
+    if (argc < 2) {
+        usage();
+    }
+
+    upto = atoll(argv[1]);
+
+    if (upto <= 0) {
+        fprintf(stderr, "N must be greater than 0\n");
+        exit(EXIT_FAILURE);
+    }
+
+    /*
+     * NOTE(review): cali_init() has no visible declaration in this file --
+     * confirm which header is meant to provide it.
+     */
+    cali_init();
+
+    MPI_Init(&argc, &argv);
+
+    MPI_Comm_set_errhandler(MPI_COMM_WORLD,
+                            MPI_ERRORS_RETURN);
+
+    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+    size_t chunk_size;
+    size_t rest;
+    uint64_t total = 0;
+
+    if (size > 1) {
+        chunk_size = upto / (size - 1);
+        rest = upto % (size - 1);
+    } else {
+        chunk_size = 0;
+        rest = upto;
+    }
+
+    if (rank == 0) {
+        int workers = size - 1;
+        uint64_t sums[size];
+        /* A VLA needs at least one element, even without any worker. */
+        MPI_Request requests[workers > 0 ? workers : 1];
+
+        values = allocate_values(upto);
+
+        for (int k=1; k<size; ++k) {
+            send_values(k,
+                        values + (chunk_size * (k - 1)),
+                        chunk_size,
+                        &requests[k-1]);
+        }
+
+        /* Rank 0 handles what is left after the even split. */
+        sums[0] = sum_of(values + chunk_size * (size - 1),
+                         rest);
+
+        /* An array of requests takes MPI_STATUSES_IGNORE (plural). */
+        MPI_Waitall(workers, requests, MPI_STATUSES_IGNORE);
+
+        for (int k=1; k<size; ++k) {
+            recv_answer(k, &sums[k], &requests[k-1]);
+        }
+
+        MPI_Waitall(workers, requests, MPI_STATUSES_IGNORE);
+
+        total = sum_of(sums, size);
+
+        /* All sends completed at the first MPI_Waitall. */
+        free(values);
+    } else {
+        values = recv_values(chunk_size);
+        send_answer(sum_of(values, chunk_size));
+        free(values);
+    }
+
+    MPI_Finalize();
+
+    if (rank == 0){
+        assert(total ==
+               (((uint64_t)upto * ((uint64_t)upto + 1U)) >> 1U));
+    }
+
+    return 0;
+}
This page took 0.05159 seconds and 4 git commands to generate.