|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
- import argparse
- import itertools
- import os
-
- from gen_elemwise_multi_type_utils import ( # isort: skip; isort: skip
- MODES,
- QINT32_MODES,
- SUPPORT_DTYPES,
- SUPPORT_QINT32_DTYPES,
- )
-
-
def generate(modes, support_dtypes, output, cpp_ext):
    """Write one kernel-instantiation source file per mode and dtype pair.

    For every key of ``modes`` combined with every entry of
    ``support_dtypes``, each mode listed under that key gets a file named
    ``<mode>_<src>_<dst>.<cpp_ext>`` inside ``output``. Each file consists of
    a "generated by" marker, four ``#define KERN_IMPL_*`` lines and an
    ``#include`` of ``../kern_impl.inl``. Progress is reported on stdout.

    Args:
        modes: Mapping whose keys are written as ``KERN_IMPL_ARITY`` and
            whose values are iterables of mode names.
        support_dtypes: Iterable of indexable pairs; item 0 is used as the
            source ctype and item 1 as the destination ctype.
        output: Directory that receives the generated files. It must
            already exist; this function does not create it.
        cpp_ext: File extension to use, without the leading dot.

    Raises:
        OSError: If a file cannot be created inside ``output``.
    """
    for anum, ctype in itertools.product(modes, support_dtypes):
        print("{} : {}".format(anum, ctype))
        src_ctype, dst_ctype = ctype[0], ctype[1]
        for mode in modes[anum]:
            fname = os.path.join(
                output,
                "{}_{}_{}.{}".format(mode, src_ctype, dst_ctype, cpp_ext),
            )
            lines = [
                "// generated by gen_elemwise_multi_type_kern_impls.py",
                "#define KERN_IMPL_MODE(cb) "
                "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode),
                "#define KERN_IMPL_ARITY {}".format(anum),
                "#define KERN_IMPL_STYPE {}".format(src_ctype),
                "#define KERN_IMPL_DTYPE {}".format(dst_ctype),
                '#include "../kern_impl.inl"',
            ]
            # Explicit encoding keeps the generated sources identical
            # regardless of the locale of the machine running the script.
            with open(fname, "w", encoding="utf-8") as fout:
                fout.write("\n".join(lines) + "\n")
            print("generated {}".format(fname))
-
-
def main(argv=None):
    """Parse the command line and generate every multi-type kernel file.

    Creates the output directory when needed, runs ``generate`` once for
    the regular modes and once for the QINT32 modes, then updates the
    output directory's timestamps.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) lets ``argparse`` read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda"],
        default="cuda",
        help="generate cuda kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args(argv)

    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(args.output, exist_ok=True)

    # argparse's ``choices`` already rejects unknown types, so a plain table
    # lookup is enough and, unlike ``assert``, it still runs under ``-O``.
    extensions = {"cuda": "cu"}
    cpp_ext = extensions[args.type]

    generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext)
    generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext)
    # With no explicit times, os.utime sets the directory's access and
    # modification times to the current time.
    # NOTE(review): presumably this lets a build system see the directory as
    # freshly regenerated; confirm against the build rules.
    os.utime(args.output)


if __name__ == "__main__":
    main()
|