Add explicit MIT license and copyright
[deliverable/lttng-ust-mpi.git] / lttng-auto-mpi-wrappers
1 #!/usr/bin/env python3
2 #
3 # SPDX-License-Identifier: MIT
4 #
5 # SPDX-FileCopyrightText: 2023 EfficiOS, Inc.
6 #
7 # Author: Olivier Dion <odion@efficios.com>
8 #
9 # Auto-generate lttng-ust tracepoints for OpenMPI.
10 #
11 # Require: python-clang (libclang)
12
13 import argparse
14 import re
15
16 from string import Template
17
18 import clang.cindex
19
def list_function_declarations(root):
    """Return the direct children of *root* whose cursor kind is FUNCTION_DECL.

    Only the immediate children are inspected; nested declarations are not
    visited.
    """
    function_decl_kind = clang.cindex.CursorKind.FUNCTION_DECL
    return [node for node in root.get_children() if node.kind == function_decl_kind]
24
def parse_header(header_file):
    """Parse *header_file* with libclang and return the translation unit's root cursor."""
    index = clang.cindex.Index.create()
    translation_unit = index.parse(header_file)
    return translation_unit.cursor
27
def list_functions(root):
    """Return the ``MPI_*`` function declarations found directly under *root*.

    Declarations are returned in header order. Only the first declaration of
    each name is kept: C allows a function to be declared more than once, and
    emitting one wrapper per declaration would produce duplicate definitions
    in the generated C file.
    """
    seen = set()
    functions = []
    for fn in list_function_declarations(root):
        name = fn.spelling
        # A name starting with "MPI_" is necessarily non-empty, so no separate
        # emptiness check is needed.
        if not name.startswith("MPI_") or name in seen:
            continue
        seen.add(name)
        functions.append(fn)
    return functions
34
def exact_definition(arg):
    """Render the C parameter declaration for the clang cursor *arg*.

    For array types such as ``int [3]``, the bracket suffix is moved after
    the parameter name (``int  ranges[3]``), as C declarator syntax requires.
    Other types are rendered as ``<type> <name>``.
    """
    type_spelling = arg.type.spelling
    dims = re.search(r'(\[[0-9]*\])+', type_spelling)
    if not dims:
        return f"{type_spelling} {arg.spelling}"
    base_type = type_spelling[:dims.start(0)]
    return f"{base_type} {arg.spelling}{dims.group(0)}"
41
# Names of MPI functions for which no wrapper is generated.
# NOTE(review): MPI_Pcontrol is presumably excluded because its C binding is
# variadic, and the wrapper template cannot forward a `...` argument list.
# Confirm before extending this set.
forbiden_list = {"MPI_Pcontrol"}
45
# C snippet appended to the wrapper of each MPI initialization call.
#
# After a successful call, it does two things:
#   - resolves PMPI_Comm_rank and stores this process's rank in
#     MPI_COMM_WORLD into `mpi_rank`;
#   - registers `mpi_provider` as an LTTng-UST context provider.
#
# `ret`, `mpi_rank`, `mpi_provider` and `resolve_or_die` are the wrapper's
# return variable and names defined in the generated file's preamble.
#
# The snippet is shared by MPI_Init and MPI_Init_thread because a program may
# initialize MPI through either call; wrapping only MPI_Init left the rank
# context unregistered for MPI_Init_thread users. The NULL check keeps
# registration idempotent should both wrappers ever run in one process.
_INIT_EXTRA_WORK = """
    if (MPI_SUCCESS == ret && NULL == mpi_provider.priv) {
        int (*mpi_comm_rank)(MPI_Comm, int *rank);
        MPI_Comm mpi_comm_world;
#ifdef CRAY_MPICH_VERSION
        mpi_comm_world = MPI_COMM_WORLD;
#else
        mpi_comm_world = *(void**)resolve_or_die("ompi_mpi_comm_world_addr");
#endif
        mpi_comm_rank = resolve_or_die("PMPI_Comm_rank");
        mpi_comm_rank(mpi_comm_world, &mpi_rank);
        mpi_provider.priv = lttng_ust_context_provider_register(&mpi_provider);
    }
"""

# Extra C code inserted (via the `$extra_work` template placeholder) after the
# real call in the wrapper of the named MPI function, before `return ret;`.
extra_works = {
    "MPI_Init": _INIT_EXTRA_WORK,
    "MPI_Init_thread": _INIT_EXTRA_WORK,
    # Unregister the context provider and clear the handle, so a later
    # check of `mpi_provider.priv` does not see a stale registration.
    "MPI_Finalize": """
    if (mpi_provider.priv) {
        lttng_ust_context_provider_unregister(mpi_provider.priv);
        mpi_provider.priv = NULL;
    }
""",
}
67
def _parse_args():
    """Parse the command line: the MPI header to scan and the output C path."""
    parser = argparse.ArgumentParser(prog="lttng-ust-auto-mpi")
    parser.add_argument("api", help="MPI API header")
    parser.add_argument("wrappers", help="Path to MPI wrappers")
    return parser.parse_args()


# C code written once at the top of the generated wrappers file. It provides:
#   - `resolve_or_die`, which looks up the next definition of a symbol via
#     dlsym(RTLD_NEXT) and exits the process if the lookup fails;
#   - the "$app.MPI" LTTng-UST context provider, which exposes `mpi_rank`.
_PREAMBLE = """/* Auto-generated */
#define _GNU_SOURCE

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <dlfcn.h>

#include <mpi.h>

#include <lttng/ust-events.h>
#include <lttng/ust-ringbuffer-context.h>

#include "lttng/ust-context-provider.h"

#include "lttng-ust-mpi-states.h"

#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)

#define die(fmt, ...) \\
    do { \\
        fprintf(stderr, fmt "\\n", ##__VA_ARGS__); \\
        exit(EXIT_FAILURE); \\
    } while (0)

static void *resolve_or_die(const char *symbol)
{
    void *ret = dlsym(RTLD_NEXT, symbol);
    if (unlikely(!ret)) {
        die("could not resolve `%s': %s", symbol, dlerror());
    }
    return ret;
}

static inline int streq(const char *A, const char *B)
{
    return 0 == strcmp(A, B);
}

static inline char *context_type(struct lttng_ust_app_context *uctx)
{
    char *suffix = strchr(uctx->ctx_name, ':');

    if (likely(suffix)) {
        suffix = &suffix[1]; /* Skip ':' */
    }

    return suffix;
}

static int mpi_rank = -1;

static size_t mpi_provider_get_size(void *uctx,
        struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
        size_t offset)
{
    size_t size = 0;
    char *type = context_type(uctx);

    size += lttng_ust_ring_buffer_align(offset, lttng_ust_rb_alignof(char));
    size += sizeof(char);

    if (unlikely(!type)) {
        goto error;
    }

    if (streq(type, "rank")) {
        /*
         * mpi_provider_record writes the payload right after the tag byte,
         * so its padding must be computed at offset + size, not at offset.
         */
        size += lttng_ust_ring_buffer_align(offset + size,
                lttng_ust_rb_alignof(int64_t));
        size += sizeof(int64_t);

    } else {
error:
        /* Unknown context. */
        (void) size;
    }

    return size;
}

static void mpi_provider_record(void *uctx,
        struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
        struct lttng_ust_ring_buffer_ctx *ctx,
        struct lttng_ust_channel_buffer *lttng_chan_buf)
{
    int sel;
    char sel_char;
    char *type = context_type(uctx);

    if (unlikely(!type)) {
        goto error;
    }

    if (streq(type, "rank")) {
        int64_t v;
        sel = LTTNG_UST_DYNAMIC_TYPE_S64;
        sel_char = (char) sel;
        v = (int64_t) mpi_rank;
        lttng_chan_buf->ops->event_write(ctx, &sel_char, sizeof(sel_char),
                lttng_ust_rb_alignof(char));

        lttng_chan_buf->ops->event_write(ctx, &v, sizeof(v), lttng_ust_rb_alignof(v));
    } else {
error:
        sel = LTTNG_UST_DYNAMIC_TYPE_NONE;
        sel_char = (char) sel;
        lttng_chan_buf->ops->event_write(ctx, &sel_char, sizeof(sel_char),
                lttng_ust_rb_alignof(char));
    }
}

static void mpi_provider_get_value(void *uctx,
        struct lttng_ust_probe_ctx *probe_ctx __attribute__((unused)),
        struct lttng_ust_ctx_value *value)
{
    char *type = context_type(uctx);

    if (unlikely(!type)) {
        goto error;
    }

    if (streq(type, "rank")) {
        value->sel = LTTNG_UST_DYNAMIC_TYPE_S64;
        value->u.s64 = (int64_t) mpi_rank;
    } else {
error:
        value->sel = LTTNG_UST_DYNAMIC_TYPE_NONE;
    }
}

static struct lttng_ust_context_provider mpi_provider = {
    .struct_size = sizeof(struct lttng_ust_context_provider),
    .name = "$app.MPI",
    .get_size = mpi_provider_get_size,
    .record = mpi_provider_record,
    .get_value = mpi_provider_get_value
};
"""

# Per-function wrapper. Each wrapper does the following:
#   - lazily resolves the "P"-prefixed (PMPI) symbol once;
#   - brackets the real call with the tracing macros;
#   - runs the optional `extra_work` snippet before returning the result.
_FN_TPL = Template("""
${ret_type} ${fn_name}(${fn_arguments})
{
    ${ret_type} ret;
    {
        static ${ret_type}(*real_fn)(${fn_arguments}) = NULL;
        if (unlikely(NULL == __atomic_load_n(&real_fn, __ATOMIC_RELAXED))) {
            void *result = resolve_or_die("P${fn_name}");
            __atomic_store_n(&real_fn, result, __ATOMIC_RELAXED);
        }
        LTTNG_MAKE_API_OBJECT(${fn_name}${fn_rest_argument_names});
        ret = real_fn(${fn_pass_argument_names});
        LTTNG_MARK_RETURN_API_OBJECT(ret);
    }
    $extra_work
    return ret;
}
""")


def _render_wrapper(fn):
    """Return the C source of the tracing wrapper for the function cursor *fn*."""
    # Materialize once and reuse for the three renderings below.
    fn_args = list(fn.get_arguments())
    fn_arguments = ", ".join(exact_definition(arg) for arg in fn_args)
    fn_pass_argument_names = ", ".join(arg.spelling for arg in fn_args)
    # Array parameters are cast to a pointer type for the tracing macro.
    # Each entry carries its own leading ", " because it follows fn_name.
    fn_rest_argument_names = "".join(
        ", (%s)%s" % (re.sub(r'\[[0-9]*\]', '*', arg.type.spelling), arg.spelling)
        for arg in fn_args
    )
    return _FN_TPL.substitute(
        fn_name=fn.spelling,
        fn_arguments=fn_arguments,
        fn_pass_argument_names=fn_pass_argument_names,
        fn_rest_argument_names=fn_rest_argument_names,
        ret_type=fn.type.get_result().spelling,
        extra_work=extra_works.get(fn.spelling, ""),
    )


def main():
    """Generate the LTTng-UST MPI wrappers C file.

    Command line: ``api`` is the MPI header scanned with libclang, and
    ``wrappers`` is the C file to (over)write. Every ``MPI_*`` function
    declared in the header gets a wrapper, except names in ``forbiden_list``.
    """
    # Kept distinct from each function's argument list, which previously
    # reused (and shadowed) the name `args` inside the loop.
    cli_args = _parse_args()

    # Parse first, so a failure while reading the header does not leave a
    # truncated wrappers file behind.
    functions = list_functions(parse_header(cli_args.api))

    with open(cli_args.wrappers, "w") as output:
        output.write(_PREAMBLE)
        for fn in functions:
            if fn.spelling in forbiden_list:
                continue
            output.write(_render_wrapper(fn))
272
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
This page took 0.035819 seconds and 4 git commands to generate.