Suggest default CUDA/cuDNN/TensorRT versions after initial auto-detection.
Reduce blank-line spew in the message printed when find_cuda_config.py fails to find a required file.
PiperOrigin-RevId: 243810113
diff --git a/configure.py b/configure.py
index fe0e6d7..1a69b06 100644
--- a/configure.py
+++ b/configure.py
@@ -33,6 +33,9 @@
from distutils.spawn import find_executable as which
# pylint: enable=g-import-not-at-top
+_DEFAULT_CUDA_VERSION = '10'
+_DEFAULT_CUDNN_VERSION = '7'
+_DEFAULT_TENSORRT_VERSION = '5'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_TF_OPENCL_VERSION = '1.2'
@@ -865,21 +868,25 @@
def set_tf_cuda_version(environ_cp):
"""Set TF_CUDA_VERSION."""
- ask_cuda_version = ('Please specify the CUDA SDK version you want to use. '
- '[Leave empty to accept any version]: ')
+ ask_cuda_version = (
+ 'Please specify the CUDA SDK version you want to use. '
+ '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
tf_cuda_version = get_from_env_or_user_or_default(environ_cp,
'TF_CUDA_VERSION',
- ask_cuda_version, '')
+ ask_cuda_version,
+ _DEFAULT_CUDA_VERSION)
environ_cp['TF_CUDA_VERSION'] = tf_cuda_version
def set_tf_cudnn_version(environ_cp):
"""Set TF_CUDNN_VERSION."""
- ask_cudnn_version = ('Please specify the cuDNN version you want to use. '
- '[Leave empty to accept any version]: ')
+ ask_cudnn_version = (
+ 'Please specify the cuDNN version you want to use. '
+ '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION
tf_cudnn_version = get_from_env_or_user_or_default(environ_cp,
'TF_CUDNN_VERSION',
- ask_cudnn_version, '')
+ ask_cudnn_version,
+ _DEFAULT_CUDNN_VERSION)
environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version
@@ -922,11 +929,10 @@
ask_tensorrt_version = (
'Please specify the TensorRT version you want to use. '
- '[Leave empty to accept any version]: ')
- tf_tensorrt_version = get_from_env_or_user_or_default(environ_cp,
- 'TF_TENSORRT_VERSION',
- ask_tensorrt_version,
- '')
+ '[Leave empty to default to TensorRT %s]: ') % _DEFAULT_TENSORRT_VERSION
+ tf_tensorrt_version = get_from_env_or_user_or_default(
+ environ_cp, 'TF_TENSORRT_VERSION', ask_tensorrt_version,
+ _DEFAULT_TENSORRT_VERSION)
environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
@@ -1331,7 +1337,7 @@
if proc.wait():
# Errors from find_cuda_config.py were sent to stderr.
- print('\n\nAsking for detailed CUDA configuration...\n')
+ print('Asking for detailed CUDA configuration...\n')
return False
config = dict(