Switch from TF_TENSORRT_VERSION to TF_NEED_TENSORRT environment variable to determine whether to build with TF-TRT.
Store fallback env variables for CUDA configs in bazelrc as well.
PiperOrigin-RevId: 245961133
diff --git a/configure.py b/configure.py
index b4d7778..8d6c6e5 100644
--- a/configure.py
+++ b/configure.py
@@ -1324,9 +1324,9 @@
cuda_libraries = ['cuda', 'cudnn']
if is_linux():
- if 'TF_TENSORRT_VERSION' in environ_cp: # if env variable exists
+ if environ_cp.get('TF_NEED_TENSORRT', None):
cuda_libraries.append('tensorrt')
- if environ_cp.get('TF_NCCL_VERSION', None): # if env variable not empty
+ if environ_cp.get('TF_NCCL_VERSION', None):
cuda_libraries.append('nccl')
proc = subprocess.Popen(
@@ -1453,8 +1453,12 @@
cuda_env_names = [
'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
- 'CUDA_TOOLKIT_PATH'
+ # Items below are for backwards compatibility when not using
+ # TF_CUDA_PATHS.
+ 'CUDA_TOOLKIT_PATH', 'CUDNN_INSTALL_PATH', 'NCCL_INSTALL_PATH',
+ 'NCCL_HDR_PATH', 'TENSORRT_INSTALL_PATH'
]
+ # Note: set_action_env_var above already writes to bazelrc.
for name in cuda_env_names:
if name in environ_cp:
write_action_env_to_bazelrc(name, environ_cp[name])
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 607210b..380e642 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -335,10 +335,8 @@
return "\n".join(inc_entries)
def enable_cuda(repository_ctx):
- if "TF_NEED_CUDA" in repository_ctx.os.environ:
- enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
- return enable_cuda == "1"
- return False
+ """Returns whether to build with CUDA support."""
+ return int(repository_ctx.os.environ.get("TF_NEED_CUDA", False))
def matches_version(environ_version, detected_version):
"""Checks whether the user-specified version matches the detected version.
diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl
index 5c2dded..728c318 100644
--- a/third_party/tensorrt/tensorrt_configure.bzl
+++ b/third_party/tensorrt/tensorrt_configure.bzl
@@ -18,6 +18,7 @@
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
+_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
@@ -43,6 +44,10 @@
"%{tensorrt_libs}": "[]",
})
+def enable_tensorrt(repository_ctx):
+ """Returns whether to build with TensorRT support."""
+ return int(repository_ctx.os.environ.get(_TF_NEED_TENSORRT, False))
+
def _tensorrt_configure_impl(repository_ctx):
"""Implementation of the tensorrt_configure repository rule."""
if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
@@ -56,7 +61,7 @@
)
return
- if _TF_TENSORRT_VERSION not in repository_ctx.os.environ:
+ if not enable_tensorrt(repository_ctx):
_create_dummy_repository(repository_ctx)
return
@@ -99,6 +104,7 @@
_TENSORRT_INSTALL_PATH,
_TF_TENSORRT_VERSION,
_TF_TENSORRT_CONFIG_REPO,
+ _TF_NEED_TENSORRT,
"TF_CUDA_PATHS",
],
)