diff --git a/src/reduce.cu b/src/reduce.cu index df7b3af..eb38ddd 100644 --- a/src/reduce.cu +++ b/src/reduce.cu @@ -41,7 +41,11 @@ struct CTAReduce { T shuff; for (int offset = warp_size / 2; offset > 0; offset /= 2) { +#if CUDART_VERSION < 9000 shuff = __shfl_down(x, offset); +#else + shuff = __shfl_down_sync(0xFFFFFFFF, x, offset); +#endif if (tid + offset < count && tid < offset) x = g(x, shuff); } diff --git a/tensorflow_binding/setup.py b/tensorflow_binding/setup.py index a6682c0..e3404db 100644 --- a/tensorflow_binding/setup.py +++ b/tensorflow_binding/setup.py @@ -10,6 +10,7 @@ import unittest import warnings from setuptools.command.build_ext import build_ext as orig_build_ext +from distutils.version import LooseVersion # We need to import tensorflow to find where its include directory is. try: @@ -27,13 +28,6 @@ else: enable_gpu = True - -if "TENSORFLOW_SRC_PATH" not in os.environ: - print("Please define the TENSORFLOW_SRC_PATH environment variable.\n" - "This should be a path to the Tensorflow source directory.", - file=sys.stderr) - sys.exit(1) - if platform.system() == 'Darwin': lib_ext = ".dylib" else: @@ -52,13 +46,16 @@ root_path = os.path.realpath(os.path.dirname(__file__)) tf_include = tf.sysconfig.get_include() -tf_src_dir = os.environ["TENSORFLOW_SRC_PATH"] +tf_src_dir = tf.sysconfig.get_lib() # os.environ["TENSORFLOW_SRC_PATH"] tf_includes = [tf_include, tf_src_dir] warp_ctc_includes = [os.path.join(root_path, '../include')] include_dirs = tf_includes + warp_ctc_includes -if tf.__version__ >= '1.4': - include_dirs += [tf_include + '/../../external/nsync/public'] +if LooseVersion(tf.__version__) >= LooseVersion('1.4'): + nsync_dir = '../../external/nsync/public' + if LooseVersion(tf.__version__) >= LooseVersion('1.10'): + nsync_dir = 'external/nsync/public' + include_dirs += [os.path.join(tf_include, nsync_dir)] if os.getenv("TF_CXX11_ABI") is not None: TF_CXX11_ABI = os.getenv("TF_CXX11_ABI") @@ -78,9 +75,9 @@ extra_compile_args += ['-Wno-return-type'] extra_link_args = [] -if tf.__version__ >= '1.4': +if LooseVersion(tf.__version__) >= LooseVersion('1.4'): if os.path.exists(os.path.join(tf_src_dir, 'libtensorflow_framework.so')): - extra_link_args = ['-L' + tf.sysconfig.get_lib(), '-ltensorflow_framework'] + extra_link_args = ['-L' + tf_src_dir, '-ltensorflow_framework'] if (enable_gpu): extra_compile_args += ['-DWARPCTC_ENABLE_GPU']