|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
- import os
- import argparse
- import itertools
- from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES
-
def generate(modes, support_dtypes, output, cpp_ext):
    """Write one kernel-implementation source file per (mode, dtype pair).

    Every arity key in *modes* is combined with every entry of
    *support_dtypes*. For each mode listed under that arity, a file named
    ``{mode}_{src_ctype}_{dst_ctype}.{cpp_ext}`` is written into *output*.
    Each file defines the ``KERN_IMPL_MODE``, ``KERN_IMPL_ARITY``,
    ``KERN_IMPL_STYPE`` and ``KERN_IMPL_DTYPE`` macros, then includes
    ``../kern_impl.inl``. Existing files with the same name are overwritten.

    Args:
        modes: mapping from an arity value to an iterable of mode names.
        support_dtypes: iterable of sequences whose first two items are the
            source and destination ctype names.
        output: directory the files are written into; it must already exist.
        cpp_ext: extension (without the leading dot) of the generated files.
    """
    for anum, ctype in itertools.product(modes.keys(), support_dtypes):
        print('{} : {}'.format(anum, ctype))
        src_ctype = ctype[0]
        dst_ctype = ctype[1]
        for mode in modes[anum]:
            formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode)
            fname = os.path.join(
                output,
                '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext))
            lines = [
                '// generated by gen_elemwise_multi_type_kern_impls.py',
                '#define KERN_IMPL_MODE(cb) {}'.format(formode),
                '#define KERN_IMPL_ARITY {}'.format(anum),
                '#define KERN_IMPL_STYPE {}'.format(src_ctype),
                '#define KERN_IMPL_DTYPE {}'.format(dst_ctype),
                '#include "../kern_impl.inl"',
            ]
            # A fixed encoding and LF line endings keep the generated sources
            # byte-identical regardless of host platform or locale.
            with open(fname, 'w', encoding='utf-8', newline='\n') as fout:
                fout.write('\n'.join(lines) + '\n')

            print('generated {}'.format(fname))
-
-
def main(argv=None):
    """Parse the command line and generate all elemwise multi-type kernel files.

    Creates the output directory if needed, then calls ``generate`` twice:
    once for ``MODES``/``SUPPORT_DTYPES`` and once for
    ``QINT32_MODES``/``SUPPORT_QINT32_DTYPES``. Finally it refreshes the
    output directory's timestamps.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.
    """
    # Single source of truth for supported --type values and the file
    # extension each one produces, so choices and lookup cannot drift apart.
    ext_by_type = {'cuda': 'cu'}

    parser = argparse.ArgumentParser(
        description='generate elemwise impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type', type=str, choices=sorted(ext_by_type),
                        default='cuda', help='generate cuda kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args(argv)

    # exist_ok avoids the check-then-create race of isdir() + makedirs();
    # a non-directory at that path still raises FileExistsError.
    os.makedirs(args.output, exist_ok=True)

    cpp_ext = ext_by_type[args.type]

    generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext)
    generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext)
    # os.utime with no times sets the directory's atime/mtime to "now"
    # (presumably so build tooling notices the regeneration -- confirm).
    os.utime(args.output)
-
if __name__ == '__main__':
    # Only run the generator when executed as a script, not when imported.
    main()
|