开源脉冲神经网络深度学习框架 (An open-source deep learning framework for spiking neural networks) https://spikingjelly.readthedocs.io
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

79 lines
2.4 KiB

  1. '''
  2. python setup.py sdist bdist_wheel
  3. python -m twine upload dist/*
  4. '''
  5. import setuptools
  6. import glob
  7. import os
  8. import torch
  9. from setuptools import find_packages
  10. from setuptools import setup
  11. from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension, BuildExtension
  12. import sys
  13. requirements = ["torch"]
  14. def get_extensions():
  15. if CUDA_HOME is None:
  16. print('CUDA_HOME is None. Install Without CUDA Extension')
  17. return None
  18. else:
  19. print('Install With CUDA Extension')
  20. this_dir = os.path.dirname(os.path.abspath(__file__))
  21. extensions_dir = os.path.join(this_dir, 'spikingjelly', 'cext', 'csrc')
  22. if sys.platform == 'win32' or int(torch.version.cuda.split('.')[0]) < 11:
  23. # windows and cuda<11 does not support cuSparse
  24. ext_list = ['neuron']
  25. else:
  26. ext_list = ['gemm', 'neuron']
  27. extra_compile_args = {'cxx': ['-g'], 'nvcc': ['-use_fast_math']}
  28. extension = CUDAExtension
  29. define_macros = [("WITH_CUDA", None)]
  30. ext_modules = list([
  31. extension(
  32. '_C_' + ext_name,
  33. glob.glob(os.path.join(extensions_dir, ext_name, '*.cpp')) + glob.glob(os.path.join(extensions_dir, ext_name, '*.cu')),
  34. define_macros=define_macros,
  35. extra_compile_args=extra_compile_args
  36. ) for ext_name in ext_list])
  37. return ext_modules
  38. with open("./requirements.txt", "r", encoding="utf-8") as fh:
  39. install_requires = fh.read()
  40. with open("./README.md", "r", encoding="utf-8") as fh:
  41. long_description = fh.read()
  42. setup(
  43. install_requires=install_requires,
  44. name="spikingjelly",
  45. version="0.0.0.0.5",
  46. author="PKU MLG, PCL, and other contributors",
  47. author_email="fwei@pku.edu.cn, chyq@pku.edu.cn",
  48. description="A deep learning framework for SNNs built on PyTorch.",
  49. long_description=long_description,
  50. long_description_content_type="text/markdown",
  51. url="https://github.com/fangwei123456/spikingjelly",
  52. packages=find_packages(),
  53. classifiers=[
  54. "Programming Language :: Python :: 3 :: Only",
  55. "Programming Language :: Python :: 3.6",
  56. "Programming Language :: Python :: 3.7",
  57. "Programming Language :: Python :: 3.8",
  58. "License :: Other/Proprietary License",
  59. "Operating System :: OS Independent",
  60. ],
  61. python_requires='>=3.6',
  62. ext_modules=get_extensions(),
  63. cmdclass={
  64. "build_ext": BuildExtension
  65. }
  66. )