Merge pull request #22519 from jayfurmanek:nccl2_configure

PiperOrigin-RevId: 215310536
diff --git a/configure.py b/configure.py
index 0efa11a..7e47175 100644
--- a/configure.py
+++ b/configure.py
@@ -52,6 +52,10 @@
 _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 
+NCCL_LIB_PATHS = [
+    'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
+]
+
 if platform.machine() == 'ppc64le':
   _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
 else:
@@ -1097,7 +1101,7 @@
 
 
 def set_tf_nccl_install_path(environ_cp):
-  """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION.
+  """Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION.
 
   Args:
     environ_cp: copy of the os.environ.
@@ -1123,46 +1127,107 @@
     if tf_nccl_version == '1':
       break  # No need to get install path, NCCL 1 is a GitHub repo.
 
-    # TODO(csigg): Look with ldconfig first if we can find the library in paths
+    # Look with ldconfig first if we can find the library in paths
     # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
     # include directory. This is where the NCCL .deb packages install them.
-    # Then ask the user if we should use that. Instead of a single
-    # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to
-    # nccl_configure.bzl
-    default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
-    ask_nccl_path = (r'Please specify the location where NCCL %s library is '
-                     'installed. Refer to README.md for more details. [Default '
-                     'is %s]:') % (tf_nccl_version, default_nccl_path)
-    nccl_install_path = get_from_env_or_user_or_default(
-        environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
 
-    # Result returned from "read" will be used unexpanded. That make "~"
-    # unusable. Going through one more level of expansion to handle that.
-    nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path))
-    if is_windows() or is_cygwin():
-      nccl_install_path = cygpath(nccl_install_path)
+    # First check to see if NCCL is in the ldconfig.
+    # If it is found, use that location.
+    if is_linux():
+      ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
+      nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
+      nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)',
+                                           nccl2_path_from_ldconfig)
+    if nccl2_path_from_ldconfig:
+      nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1)
+      if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)):
+        nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig)
+        print('NCCL libraries found in ' + nccl2_path_from_ldconfig)
 
-    if is_windows():
-      nccl_lib_path = 'lib/x64/nccl.lib'
-    elif is_linux():
-      nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version
-    elif is_macos():
-      nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
+        # Check if this is the main system lib location
+        if re.search('.*linux-gnu', nccl_install_path):
+          trunc_nccl_install_path = '/usr'
+          print('This looks like a system path.')
+        else:
+          trunc_nccl_install_path = nccl_install_path + '/..'
 
-    nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
-    nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h')
-    if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
-      # Set NCCL_INSTALL_PATH
-      environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
-      write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
-      break
+        # Look for header
+        nccl_hdr_path = trunc_nccl_install_path + '/include'
+        print('Assuming NCCL header path is ' + nccl_hdr_path)
+        if os.path.exists(nccl_hdr_path + '/nccl.h'):
+          # Set NCCL_INSTALL_PATH
+          environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
+          write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
 
-    # Reset and Retry
-    print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
+          # Set NCCL_HDR_PATH
+          environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path
+          write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path)
+          break
+        else:
+          print(
+              'The header for NCCL2 cannot be found. Please install the libnccl-dev package.'
+          )
+      else:
+        print('NCCL2 is listed by ldconfig but the library is not found. '
+              'Your ldconfig is out of date. Please run sudo ldconfig.')
+    else:
+      # NCCL is not found in ldconfig. Ask the user for the location.
+      default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
+      ask_nccl_path = (
+          r'Please specify the location where NCCL %s library is '
+          'installed. Refer to README.md for more details. [Default '
+          'is %s]:') % (tf_nccl_version, default_nccl_path)
+      nccl_install_path = get_from_env_or_user_or_default(
+          environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
+
+      # Result returned from "read" will be used unexpanded. That makes "~"
+      # unusable. Going through one more level of expansion to handle that.
+      nccl_install_path = os.path.realpath(
+          os.path.expanduser(nccl_install_path))
+      if is_windows() or is_cygwin():
+        nccl_install_path = cygpath(nccl_install_path)
+
+      if is_windows():
+        nccl_lib_path = 'lib/x64/nccl.lib'
+      elif is_linux():
+        nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version
+        nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename)
+        if not os.path.exists(nccl_lpath):
+          for relative_path in NCCL_LIB_PATHS:
+            path = '%s/%s%s' % (nccl_install_path, relative_path,
+                                nccl_lib_filename)
+            if os.path.exists(path):
+              print('NCCL found at ' + path)
+              nccl_lib_path = path
+              break
+        else:
+          nccl_lib_path = nccl_lpath
+      elif is_macos():
+        nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
+
+      nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
+      nccl_hdr_path = os.path.join(
+          os.path.dirname(nccl_lib_path), '../include/nccl.h')
+      print('Assuming NCCL header path is ' + nccl_hdr_path)
+      if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
+        # Set NCCL_INSTALL_PATH
+        environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path)
+        write_action_env_to_bazelrc('NCCL_INSTALL_PATH',
+                                    os.path.dirname(nccl_lib_path))
+
+        # Set NCCL_HDR_PATH
+        environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path)
+        write_action_env_to_bazelrc('NCCL_HDR_PATH',
+                                    os.path.dirname(nccl_hdr_path))
+        break
+
+      # Reset and Retry
+      print(
+          'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
           'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path,
                                               nccl_hdr_path))
 
-    environ_cp['TF_NCCL_VERSION'] = ''
+      environ_cp['TF_NCCL_VERSION'] = ''
   else:
     raise UserInputError('Invalid TF_NCCL setting was provided %d '
                          'times in a row. Assuming to be a scripting mistake.' %
diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl
index ce94470..d78fe8f 100644
--- a/third_party/nccl/nccl_configure.bzl
+++ b/third_party/nccl/nccl_configure.bzl
@@ -5,6 +5,7 @@
 
   * `TF_NCCL_VERSION`: The NCCL version.
   * `NCCL_INSTALL_PATH`: The installation path of the NCCL library.
+  * `NCCL_HDR_PATH`: The directory containing the NCCL header file (`nccl.h`).
 """
 
 load(
@@ -15,6 +16,7 @@
 )
 
 _NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
+_NCCL_HDR_PATH = "NCCL_HDR_PATH"
 _TF_NCCL_VERSION = "TF_NCCL_VERSION"
 _TF_NCCL_CONFIG_REPO = "TF_NCCL_CONFIG_REPO"
 
@@ -68,7 +70,7 @@
   return header_path
 
 
-def _check_nccl_version(repository_ctx, nccl_install_path, nccl_version):
+def _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version):
   """Checks whether the header file matches the specified version of NCCL.
 
   Args:
@@ -79,7 +81,9 @@
   Returns:
     A string containing the library version of NCCL.
   """
-  header_path = _find_nccl_header(repository_ctx, nccl_install_path)
+  header_path = repository_ctx.path("%s/nccl.h" % nccl_hdr_path)
+  if not header_path.exists:
+    header_path = _find_nccl_header(repository_ctx, nccl_install_path)
   header_dir = str(header_path.realpath.dirname)
   major_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
                                    _DEFINE_NCCL_MAJOR)
@@ -138,10 +142,12 @@
   else:
     # Create target for locally installed NCCL.
     nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
-    _check_nccl_version(repository_ctx, nccl_install_path, nccl_version)
+    nccl_hdr_path = repository_ctx.os.environ[_NCCL_HDR_PATH].strip()
+    _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version)
     repository_ctx.template("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE, {
         "%{version}": nccl_version,
         "%{install_path}": nccl_install_path,
+        "%{hdr_path}": nccl_hdr_path,
     })
 
 
@@ -149,6 +155,7 @@
     implementation=_nccl_configure_impl,
     environ=[
         _NCCL_INSTALL_PATH,
+        _NCCL_HDR_PATH,
         _TF_NCCL_VERSION,
     ],
 )
diff --git a/third_party/nccl/system.BUILD.tpl b/third_party/nccl/system.BUILD.tpl
index 7ca835d..a07f549 100644
--- a/third_party/nccl/system.BUILD.tpl
+++ b/third_party/nccl/system.BUILD.tpl
@@ -20,7 +20,7 @@
     "libnccl.so.%{version}",
     "nccl.h",
   ],
-  cmd = """cp "%{install_path}/include/nccl.h" "$(@D)/nccl.h" &&
-           cp "%{install_path}/lib/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
+  cmd = """cp "%{hdr_path}/nccl.h" "$(@D)/nccl.h" &&
+           cp "%{install_path}/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
 )