Add explicit MIT license and copyright
[deliverable/exatracer.git] / scripts / gen-hsa-wrappers
CommitLineData
74c0b9b3
OD
1#!/usr/bin/env python3
2#
c38914e1
MJ
3# SPDX-FileCopyrightText: 2023 EfficiOS, Inc.
4#
5# SPDX-License-Identifier: MIT
74c0b9b3
OD
6#
7# Author: Olivier Dion <odion@efficios.com>
8#
c75d6f3f
OD
9# Auto-generate lttng-ust wrappers for HSA.
10#
11# Please refer to gen-hip-wrappers. This is basically the same thing but
12# adapted to HSA.
74c0b9b3
OD
13#
14# Require: python-clang (libclang)
15
16import argparse
17import re
18import subprocess
19
20from string import Template
21
22import clang.cindex
23
def list_function_declarations(root):
    """Return the function declarations that are direct children of *root*.

    Only the immediate children of the cursor are inspected; the walk does
    not recurse into nested scopes.
    """
    function_decl = clang.cindex.CursorKind.FUNCTION_DECL
    return list(filter(lambda cursor: cursor.kind == function_decl,
                       root.get_children()))
30
def get_system_include_paths():
    """Return clang's default ``#include <...>`` directories as ``-isystem`` flags.

    Runs ``clang -v`` on an empty C input and scrapes the directories that
    clang prints to stderr between ``#include <...> search starts here:`` and
    ``End of search list.``.  The result is a flat argument list of the form
    ``["-isystem", dir1, "-isystem", dir2, ...]`` meant to be passed to
    libclang (presumably so it finds the same system headers as the clang
    driver -- confirm if this ever changes).
    """
    command = ["clang", "-v", "-c", "-xc", "-o", "/dev/null", "/dev/null"]
    flags = []

    with subprocess.Popen(command, stderr=subprocess.PIPE, text=True) as clang_proc:
        stderr_lines = iter(clang_proc.stderr)

        # Phase 1: skip everything up to the start of the <...> search list.
        for line in stderr_lines:
            if line == "#include <...> search starts here:\n":
                break

        # Phase 2: every following line is a directory, until the terminator.
        for line in stderr_lines:
            if line == "End of search list.\n":
                break
            flags.extend(("-isystem", line.strip()))

    return flags
48
def parse_header(header_file, includes, defines):
    """Parse *header_file* with libclang and return its translation-unit cursor.

    :param header_file: path of the header to parse.
    :param includes: extra include directories, each passed as ``-I DIR``;
        may be None or empty.
    :param defines: preprocessor definitions, each passed as ``-D DEF``;
        may be None or empty.
    :returns: the root cursor of the parsed translation unit.

    clang's system include directories are always passed first.  Every
    diagnostic reported by libclang is printed with a ``WARNING:`` prefix,
    whatever its severity; parsing is never aborted.
    """
    clang_args = get_system_include_paths()
    for directory in includes or ():
        clang_args += ["-I", directory]
    for definition in defines or ():
        clang_args += ["-D", definition]

    translation_unit = clang.cindex.Index.create().parse(header_file,
                                                         args=clang_args)

    for diagnostic in translation_unit.diagnostics:
        print(f"WARNING: {diagnostic}")

    return translation_unit.cursor
69
def list_functions(root):
    """Return the HSA API functions declared directly under *root*.

    A function is selected when its name starts with ``hsa``.  When the same
    function is declared more than once, only its first declaration is kept,
    so that exactly one wrapper is generated per function.  Otherwise two
    ``lttng_<name>`` definitions would be emitted.  Declaration order is
    preserved.
    """
    seen = set()
    functions = []
    for fn in list_function_declarations(root):
        name = fn.spelling
        # startswith() is already False for an empty name, so no separate
        # emptiness check is needed.
        if not name.startswith("hsa") or name in seen:
            continue
        seen.add(name)
        functions.append(fn)
    return functions
76
def exact_definition(arg):
    """Return the C declaration of parameter *arg*, type and name combined.

    ``"<type> <name>"`` is not valid C for every type, so three cases are
    distinguished:

    * pointer to function: rebuilt as ``ret (*name)(params)`` from the
      canonical types.  Variadic prototypes keep their trailing ``...``,
      prototypes without parameters are written ``(void)``, and unprototyped
      (K&R-style) function types are written ``()``.
    * array: the dimensions must follow the name (``T name[N]...``).
    * anything else: ``<type> <name>``.
    """
    canonical = arg.type.get_canonical()
    if canonical.kind == clang.cindex.TypeKind.POINTER:
        pointee = canonical.get_pointee()
        if pointee.kind == clang.cindex.TypeKind.FUNCTIONPROTO:
            ret_type = pointee.get_result().spelling
            params = [a.spelling for a in pointee.argument_types()]
            # argument_types() only lists the fixed parameters; the ellipsis
            # of a variadic prototype has to be added back explicitly.
            if pointee.is_function_variadic():
                params.append("...")
            params_text = ", ".join(params) or "void"
            return f"{ret_type} (*{arg.spelling})({params_text})"
        if pointee.kind == clang.cindex.TypeKind.FUNCTIONNOPROTO:
            # Falling through would produce e.g. "int (*)() name", which is
            # not a valid declaration.
            ret_type = pointee.get_result().spelling
            return f"{ret_type} (*{arg.spelling})()"

    # Array dimensions (one or more "[N]") must be moved after the name.
    dims = re.search(r'(\[[0-9]*\])+', arg.type.spelling)
    if dims:
        return f"{arg.type.spelling[:dims.start(0)]} {arg.spelling}{dims.group(0)}"
    return f"{arg.type.spelling} {arg.spelling}"
91
def cast(arg):
    """Return the C type that *arg* is cast to in the generated api_object initializer.

    Any pointer type (including function pointers) is cast to ``void *``.
    For every other type, the canonical spelling is returned with each
    ``[N]`` array dimension replaced by ``*``.
    """
    canonical = arg.type.get_canonical()
    if canonical.kind != clang.cindex.TypeKind.POINTER:
        return re.sub(r'\[[0-9]*\]', '*', canonical.spelling)
    return "void *"
97
74c0b9b3 98
74c0b9b3
OD
99
def main():
    """Generate lttng-ust wrappers for the HSA API functions of a header.

    Command line::

        gen-hsa-wrappers [--ignores FILE] [-I DIR]... [-D DEFINITION]... API WRAPPERS

    The header ``API`` is parsed with libclang.  Every function returned by
    ``list_functions`` that is not named in the ignore file (one function name
    per line) gets a ``lttng_<name>`` wrapper written to ``WRAPPERS``.

    Each wrapper does three things:

    * builds a ``lttng_hsa::api_object_<name>`` from the cast arguments;
    * forwards the call to ``next_hsa_core_table.<name>_fn``;
    * reports the result through ``mark_return``.

    The output ends with ``lttng_hsa_install_wrappers()``, which stores every
    wrapper into ``lttng_hsa_core_table``.
    """

    # Extra C code inserted after the traced call, keyed by function name.
    # Empty for now; kept as an extension point.
    extra_works = {
    }

    parser = argparse.ArgumentParser(prog="gen-hsa-wrappers")

    parser.add_argument("api",
                        help="HSA API header")

    parser.add_argument("wrappers",
                        help="Path to HSA wrappers")

    parser.add_argument("--ignores",
                        dest="ignores",
                        metavar="FILE",
                        default=None,
                        help="Ignore list")

    parser.add_argument("-I",
                        action="append",
                        metavar="DIR",
                        dest="includes",
                        help="Add DIR to list of directories to include")

    parser.add_argument("-D",
                        action="append",
                        metavar="DEFINITION",
                        dest="defines",
                        help="Add DEFINITION to list of definitions")

    options = parser.parse_args()

    # Functions that must not be wrapped.  Entries are stripped and blank
    # lines skipped, so that stray whitespace in the ignore file does not
    # silently prevent a match.
    ignored = set()
    if options.ignores:
        with open(options.ignores, "r") as f:
            ignored = {
                line.strip()
                for line in f.read().splitlines()
                if line.strip()
            }

    prologue_tpl = Template("""/* Auto-generated */
#include "lttng-ust-hsa-states.h"
""")

    ret_fn_tpl = Template("""
static ${ret_type} lttng_${fn_name}(${fn_arguments})
{
    ${ret_type} ret;
    {
        lttng_hsa::api_object_${fn_name} lttng_api_object {${fn_rest_argument_names}};
        ret = next_hsa_core_table.${fn_name}_fn(${fn_pass_argument_names});
        lttng_api_object.mark_return(ret);
    }
$extra_work
    return ret;
}
""")

    void_fn_tpl = Template("""
static void lttng_${fn_name}(${fn_arguments})
{
    {
        lttng_hsa::api_object_${fn_name} lttng_api_object {${fn_rest_argument_names}};
        next_hsa_core_table.${fn_name}_fn(${fn_pass_argument_names});
        lttng_api_object.mark_return();
    }
$extra_work
}
""")

    epilogue_tpl = Template("""
static void lttng_hsa_install_wrappers(void)
{
    ${wrappers}
}
""")

    # Filter the ignore list once; both the wrappers and the installer use it.
    functions = [
        fn
        for fn in list_functions(parse_header(options.api,
                                              options.includes,
                                              options.defines))
        if fn.spelling not in ignored
    ]

    with open(options.wrappers, "w") as output:

        output.write(prologue_tpl.substitute())

        for fn in functions:
            fn_args = list(fn.get_arguments())
            result_type = fn.type.get_result()

            # Compare the canonical kind rather than the spelling "void":
            # a typedef of void must also use the void template, since
            # "T ret;" with T = void does not compile.
            if result_type.get_canonical().kind == clang.cindex.TypeKind.VOID:
                fn_tpl = void_fn_tpl
            else:
                fn_tpl = ret_fn_tpl

            output.write(fn_tpl.substitute(
                fn_name=fn.spelling,
                fn_arguments=", ".join(exact_definition(arg) for arg in fn_args),
                fn_pass_argument_names=", ".join(arg.spelling for arg in fn_args),
                fn_rest_argument_names=", ".join(
                    f"({cast(arg)}){arg.spelling}" for arg in fn_args),
                ret_type=result_type.spelling,
                extra_work=extra_works.get(fn.spelling, "")))

        output.write(epilogue_tpl.substitute(wrappers="\n    ".join(
            f"lttng_hsa_core_table.{fn.spelling}_fn = &lttng_{fn.spelling};"
            for fn in functions
        )))
228
229
if __name__ == "__main__":
    # Generate the wrappers only when run as a script, not when imported.
    main()
This page took 0.034388 seconds and 4 git commands to generate.