Make the configure script more lenient about the number of version components entered for CUDA and cuDNN. (#16853)
diff --git a/configure.py b/configure.py
index 151ad5d..5e48173 100644
--- a/configure.py
+++ b/configure.py
@@ -827,6 +827,28 @@
write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path)
+def reformat_version_sequence(version_str, sequence_count):
+  """Reformat the version string to have the given number of sequences.
+
+  Surrounding whitespace is stripped first. The dot-separated components are
+  then truncated, or right-padded with '0', so that exactly `sequence_count`
+  components remain.
+
+  For example:
+  Given (7, 2) -> 7.0
+        (7.0.1, 2) -> 7.0
+        (5, 1) -> 5
+        (5.0.3.2, 1) -> 5
+        (' 9.1 ', 2) -> 9.1
+
+  Args:
+    version_str: String, the version string.
+    sequence_count: int, the number of dot-separated components to keep.
+  Returns:
+    string, reformatted version string.
+  """
+  # Strip so that stray whitespace in user input or environment values does
+  # not end up inside the reformatted version string.
+  v = version_str.strip().split('.')
+  # Pad with zeros when too few components were given. A negative repeat
+  # count yields an empty list, so no explicit length check is needed.
+  v += ['0'] * (sequence_count - len(v))
+  return '.'.join(v[:sequence_count])
+
+
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
ask_cuda_version = (
@@ -837,6 +859,7 @@
# Configure the Cuda SDK version to use.
tf_cuda_version = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION)
+ tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2)
# Find out where the CUDA toolkit is installed
default_cuda_path = _DEFAULT_CUDA_PATH
@@ -893,6 +916,7 @@
tf_cudnn_version = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version,
_DEFAULT_CUDNN_VERSION)
+ tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version) ,1)
default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '