Commit | Line | Data |
---|---|---|
74c0b9b3 OD |
1 | #!/usr/bin/env python3 |
2 | # | |
c38914e1 MJ |
3 | # SPDX-FileCopyrightText: 2023 EfficiOS, Inc. |
4 | # | |
5 | # SPDX-License-Identifier: MIT | |
74c0b9b3 OD |
6 | # |
7 | # Author: Olivier Dion <odion@efficios.com> | |
8 | # | |
c75d6f3f OD |
9 | # Auto-generate lttng-ust wrappers for HSA. |
10 | # | |
11 | # Please refer to gen-hip-wrappers. This is basically the same thing but | |
12 | # adapted to HSA. | |
74c0b9b3 OD |
13 | # |
14 | # Require: python-clang (libclang) | |
15 | ||
16 | import argparse | |
17 | import re | |
18 | import subprocess | |
19 | ||
20 | from string import Template | |
21 | ||
22 | import clang.cindex | |
23 | ||
def list_function_declarations(root):
    """Return the FUNCTION_DECL cursors that are direct children of *root*.

    Only the immediate children are inspected (no recursion). They are
    returned in the order libclang yields them.
    """
    wanted_kind = clang.cindex.CursorKind.FUNCTION_DECL
    return [cursor for cursor in root.get_children() if cursor.kind == wanted_kind]
30 | ||
def get_system_include_paths():
    """Return ``-isystem <dir>`` argument pairs for clang's system include dirs.

    Runs ``clang -v`` on an empty C input and reads its stderr. The
    directories printed between the ``#include <...> search starts here:``
    line and the ``End of search list.`` line are collected. The result is
    a flat list: ``["-isystem", dir1, "-isystem", dir2, ...]``.
    """
    command = ["clang", "-v", "-c", "-xc", "-o", "/dev/null", "/dev/null"]
    include_args = []

    with subprocess.Popen(command, stderr=subprocess.PIPE, text=True) as driver:
        stderr_lines = iter(driver.stderr)

        # First pass: discard everything up to and including the marker
        # that opens the system search list.
        for line in stderr_lines:
            if line == "#include <...> search starts here:\n":
                break

        # Second pass (same iterator): each line is one directory until the
        # closing marker.
        for line in stderr_lines:
            if line == "End of search list.\n":
                break
            include_args.extend(("-isystem", line.strip()))

    return include_args
48 | ||
def parse_header(header_file, includes, defines):
    """Parse *header_file* with libclang and return the translation-unit cursor.

    The compiler arguments are built in this order:
      1. clang's system include directories (``-isystem`` pairs);
      2. one ``-I <dir>`` pair per entry of *includes*;
      3. one ``-D <definition>`` pair per entry of *defines*.
    *includes* and *defines* may be None or empty. Every diagnostic
    libclang reports is printed prefixed with ``WARNING:``.
    """
    compiler_args = get_system_include_paths()

    for flag, values in (("-I", includes), ("-D", defines)):
        for value in values or ():
            compiler_args += [flag, value]

    translation_unit = clang.cindex.Index.create().parse(header_file,
                                                         args=compiler_args)

    for diagnostic in translation_unit.diagnostics:
        print(f"WARNING: {diagnostic}")

    return translation_unit.cursor
69 | ||
def list_functions(root):
    """Return the ``hsa*`` function declarations directly under *root*.

    Only the first declaration of each function name is kept, and
    declaration order is preserved. A header that re-declares a function
    would otherwise make the generator emit two ``static lttng_<name>``
    wrapper definitions with the same name, which does not compile.
    """
    seen = set()
    functions = []
    for fn in list_function_declarations(root):
        name = fn.spelling
        # A name starting with "hsa" is necessarily non-empty, so no
        # separate emptiness check is needed.
        if not name.startswith("hsa") or name in seen:
            continue
        seen.add(name)
        functions.append(fn)
    return functions
76 | ||
def exact_definition(arg):
    """Return the C declaration (type and name) of parameter cursor *arg*.

    - Function-pointer parameters are rebuilt as ``ret (*name)(params)``
      from their canonical prototype. The type's own spelling
      (e.g. ``void (*)(int)``) cannot simply be followed by the name.
    - Variadic callbacks keep their trailing ``...``.
    - Pointers to unprototyped functions become ``ret (*name)()``.
    - Array parameters are written ``type name[N]``, with the bracket
      suffix moved after the name.
    - Everything else is ``type name``.
    """
    canonical = arg.type.get_canonical()
    if canonical.kind == clang.cindex.TypeKind.POINTER:
        pointee = canonical.get_pointee()
        if pointee.kind == clang.cindex.TypeKind.FUNCTIONPROTO:
            ret_type = pointee.get_result().spelling
            params = [a.spelling for a in pointee.argument_types()]
            # Dropping the ellipsis would declare a different
            # (non-variadic) function-pointer type.
            if pointee.is_function_variadic():
                params.append("...")
            return f"{ret_type} (*{arg.spelling})({', '.join(params)})"
        if pointee.kind == clang.cindex.TypeKind.FUNCTIONNOPROTO:
            # Without this branch the fallthrough would produce
            # "ret (*)() name", which is not a valid declaration.
            return f"{pointee.get_result().spelling} (*{arg.spelling})()"

    m = re.search(r'(\[[0-9]*\])+', arg.type.spelling)
    if m:
        # The spelling is e.g. "char [16]"; drop the space before the
        # brackets so the output is "char name[16]" rather than
        # "char  name[16]".
        base = arg.type.spelling[:m.start(0)].rstrip()
        return f"{base} {arg.spelling}{m.group(0)}"
    return f"{arg.type.spelling} {arg.spelling}"
91 | ||
def cast(arg):
    """Return the C type name used to cast parameter cursor *arg*.

    Any parameter whose canonical type is a pointer maps to ``void *``.
    Otherwise the canonical type spelling is returned, with every array
    bracket group ``[N]`` replaced by ``*``.
    """
    canonical = arg.type.get_canonical()
    if canonical.kind != clang.cindex.TypeKind.POINTER:
        return re.sub(r'\[[0-9]*\]', '*', canonical.spelling)
    return "void *"
97 | ||
74c0b9b3 | 98 | |
74c0b9b3 OD |
99 | |
def _parse_command_line():
    """Return the parsed command-line arguments of gen-hsa-wrappers."""
    parser = argparse.ArgumentParser(prog="gen-hsa-wrappers")

    parser.add_argument("api",
                        help="HSA API header")

    parser.add_argument("wrappers",
                        help="Path to HSA wrappers")

    parser.add_argument("--ignores",
                        dest="ignores",
                        metavar="FILE",
                        default=None,
                        help="Ignore list")

    parser.add_argument("-I",
                        action="append",
                        metavar="DIR",
                        dest="includes",
                        help="Add DIR to list of directories to include")

    parser.add_argument("-D",
                        action="append",
                        metavar="DEFINITION",
                        dest="defines",
                        help="Add DEFINITION to list of definitions")

    return parser.parse_args()


def _read_ignore_list(path):
    """Return the set of HSA function names listed in the ignore file *path*.

    The file holds one function name per line. Surrounding whitespace is
    stripped and blank lines are skipped, so stray spaces do not produce
    entries that silently fail to match. Returns an empty set when *path*
    is None or empty.
    """
    if not path:
        return set()
    with open(path, "r", encoding="utf-8") as f:
        names = (line.strip() for line in f.read().splitlines())
        return {name for name in names if name}


def main():
    """Write lttng-ust tracing wrappers for the HSA API described by a header.

    The header named on the command line is parsed with libclang. For every
    ``hsa*`` function not in the ignore list, a ``static lttng_<name>``
    wrapper is written. It creates an ``lttng_hsa::api_object_<name>``,
    forwards the call through ``next_hsa_core_table`` and marks the return.
    The output ends with ``lttng_hsa_install_wrappers()``, which stores the
    address of each generated wrapper in ``lttng_hsa_core_table``.
    """
    # Extra C++ code spliced after the traced call, keyed by HSA function
    # name. Empty for now; kept as an extension hook.
    extra_works = {
    }

    # Named `cli` (not `args`) so it is not confused with, or shadowed by,
    # per-function argument lists below.
    cli = _parse_command_line()
    ignored = _read_ignore_list(cli.ignores)

    prologue_tpl = Template("""/* Auto-generated */
#include "lttng-ust-hsa-states.h"
""")

    ret_fn_tpl = Template("""
static ${ret_type} lttng_${fn_name}(${fn_arguments})
{
${ret_type} ret;
{
lttng_hsa::api_object_${fn_name} lttng_api_object {${fn_rest_argument_names}};
ret = next_hsa_core_table.${fn_name}_fn(${fn_pass_argument_names});
lttng_api_object.mark_return(ret);
}
$extra_work
return ret;
}
""")

    void_fn_tpl = Template("""
static void lttng_${fn_name}(${fn_arguments})
{
{
lttng_hsa::api_object_${fn_name} lttng_api_object {${fn_rest_argument_names}};
next_hsa_core_table.${fn_name}_fn(${fn_pass_argument_names});
lttng_api_object.mark_return();
}
$extra_work
}
""")

    epilogue_tpl = Template("""
static void lttng_hsa_install_wrappers(void)
{
${wrappers}
}
""")

    # Parse (and filter) before opening the output file: opening with "w"
    # truncates it, so a parse failure must happen first.
    functions = [
        fn
        for fn in list_functions(parse_header(cli.api,
                                              cli.includes,
                                              cli.defines))
        if fn.spelling not in ignored
    ]

    with open(cli.wrappers, "w", encoding="utf-8") as output:

        output.write(prologue_tpl.substitute())

        for fn in functions:
            fn_args = list(fn.get_arguments())
            ret_type = fn.type.get_result().spelling
            fn_tpl = void_fn_tpl if ret_type == "void" else ret_fn_tpl

            output.write(fn_tpl.substitute(
                fn_name=fn.spelling,
                fn_arguments=", ".join(exact_definition(a) for a in fn_args),
                fn_pass_argument_names=", ".join(a.spelling for a in fn_args),
                fn_rest_argument_names=", ".join(
                    f"({cast(a)}){a.spelling}" for a in fn_args
                ),
                ret_type=ret_type,
                extra_work=extra_works.get(fn.spelling, ""),
            ))

        # `&lttng_<name>` takes the address of each wrapper generated above.
        output.write(epilogue_tpl.substitute(wrappers="\n ".join(
            f"lttng_hsa_core_table.{fn.spelling}_fn = &lttng_{fn.spelling};"
            for fn in functions
        )))


if __name__ == "__main__":
    main()