|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
- import argparse
- import itertools
- import os
-
- from gen_elemwise_utils import ARITIES, DTYPES, MODES
-
-
# Output file extension for each supported ``--type`` value; the keys also
# serve as the argparse choices, so the two can never drift apart.
_CPP_EXTS = {
    "cuda": "cu",
    "hip": "cpp.hip",
    "cpp": "cpp",
}

# ctypes whose generated sources are wrapped in ``#if !MEGDNN_DISABLE_FLOAT16``.
_FLOAT16_CTYPES = ("dt_float16", "dt_bfloat16")


def _kern_impl_source(mode, anum, ctype):
    """Return the full text of one generated kernel implementation file.

    The file defines KERN_IMPL_MODE / KERN_IMPL_ARITY / KERN_IMPL_CTYPE and
    then includes ``../kern_impl.inl``. For float16-family ctypes the body
    is guarded by ``#if !MEGDNN_DISABLE_FLOAT16`` ... ``#endif``.

    Args:
        mode: elemwise mode name, substituted into MEGDNN_ELEMWISE_MODE_ENABLE.
        anum: arity value, emitted verbatim as KERN_IMPL_ARITY.
        ctype: C type name (e.g. ``dt_float16``), emitted as KERN_IMPL_CTYPE.

    Returns:
        The file contents, one directive per line, newline-terminated.
    """
    guarded = ctype in _FLOAT16_CTYPES
    lines = ["// generated by gen_elemwise_kern_impls.py"]
    if guarded:
        lines.append("#if !MEGDNN_DISABLE_FLOAT16")
    lines.append(
        "#define KERN_IMPL_MODE(cb) {}".format(
            "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode)
        )
    )
    lines.append("#define KERN_IMPL_ARITY {}".format(anum))
    lines.append("#define KERN_IMPL_CTYPE {}".format(ctype))
    lines.append('#include "../kern_impl.inl"')
    if guarded:
        lines.append("#endif")
    return "\n".join(lines) + "\n"


def main():
    """Parse the command line and write one impl file per (arity, ctype, mode).

    Files are named ``{mode}_{ctype}.{ext}``, where ``ext`` depends on
    ``--type``. They are written into the ``output`` directory, which is
    created if missing. One ``generated <path>`` line is printed per file.
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=list(_CPP_EXTS),
        default="cpp",
        help="generate cuda/hip kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    # exist_ok avoids the isdir()/makedirs() race when several generator
    # invocations (e.g. parallel build targets) share one output directory.
    os.makedirs(args.output, exist_ok=True)

    # argparse ``choices`` guarantees the key exists; no assert needed.
    cpp_ext = _CPP_EXTS[args.type]

    for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()):
        # The MODES table is keyed by (arity, second field of the DTYPES
        # entry); presumably that field is a dtype category — confirm in
        # gen_elemwise_utils.
        for mode in MODES[(anum, DTYPES[ctype][1])]:
            fname = os.path.join(
                args.output, "{}_{}.{}".format(mode, ctype, cpp_ext)
            )
            with open(fname, "w", encoding="utf-8") as fout:
                fout.write(_kern_impl_source(mode, anum, ctype))
            print("generated {}".format(fname))

    # Touch the output directory's timestamps. NOTE(review): presumably so
    # build tooling treats the directory as freshly regenerated — confirm.
    os.utime(args.output)


if __name__ == "__main__":
    main()
|