Merge changes from github.

PiperOrigin-RevId: 167401527
diff --git a/README.md b/README.md
index 87c7b1b..5a0739a 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,15 @@
 
 People who are a little more adventurous can also try our nightly binaries:
 
+**Nightly pip packages**
+* We are pleased to announce that TensorFlow now offers nightly pip packages
+under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) project on PyPI.
+Simply run `pip install tf-nightly` in a clean environment to install the nightly
+TensorFlow build. We currently only support CPU-only packages on Linux and Mac.
+GPU packages on all platforms and Windows CPU-only packages will arrive soon!
 
+
+**Individual whl files**
 * Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
 * Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.3.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
 * Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.3.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
diff --git a/RELEASE.md b/RELEASE.md
index d120f06..3d497db 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -46,7 +46,7 @@
   * Display feed values with the `print_feed` or `pf` command and clickable links in the curses UI.
   * Runtime profiler at the op level and the Python source line level with the `run -p` command.
 * Initial release of the statistical distribution library `tf.distributions`.
-* GPU kernels and speed improvements for for unary `tf.where` and `tf.nn.top_k`.
+* GPU kernels and speed improvements for unary `tf.where` and `tf.nn.top_k`.
 * Monotonic Attention wrappers added to `tf.contrib.seq2seq`.
 * Added `tf.contrib.signal`, a library for signal processing primitives.
 * Added `tf.contrib.resampler`, containing CPU and GPU ops for differentiable resampling of images.
diff --git a/WORKSPACE b/WORKSPACE
index 5e9b991..a0fe67b 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -2,11 +2,11 @@
 
 http_archive(
     name = "io_bazel_rules_closure",
-    sha256 = "bc41b80486413aaa551860fc37471dbc0666e1dbb5236fb6177cb83b0c105846",
-    strip_prefix = "rules_closure-dec425a4ff3faf09a56c85d082e4eed05d8ce38f",
+    sha256 = "25f5399f18d8bf9ce435f85c6bbf671ec4820bc4396b3022cc5dc4bc66303609",
+    strip_prefix = "rules_closure-0.4.2",
     urls = [
-        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dec425a4ff3faf09a56c85d082e4eed05d8ce38f.tar.gz",  # 2017-06-02
-        "https://github.com/bazelbuild/rules_closure/archive/dec425a4ff3faf09a56c85d082e4eed05d8ce38f.tar.gz",
+        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz",  # 2017-08-29
+        "https://github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz",
     ],
 )
 
diff --git a/configure.py b/configure.py
index 1a0f71e..186fdc9 100644
--- a/configure.py
+++ b/configure.py
@@ -143,7 +143,7 @@
 
 def cygpath(path):
   """Convert path from posix to windows."""
-  return run_shell(['cygpath', '-m', path])
+  return os.path.abspath(path).replace('\\', '/')
 
 
 def get_python_path(environ_cp, python_bin_path):
@@ -196,7 +196,7 @@
     environ_cp['PYTHON_BIN_PATH'] = ''
 
   # Convert python path to Windows style before checking lib and version
-  if is_cygwin():
+  if is_windows() or is_cygwin():
     python_bin_path = cygpath(python_bin_path)
 
   # Get PYTHON_LIB_PATH
@@ -219,7 +219,7 @@
   python_major_version = get_python_major_version(python_bin_path)
 
   # Convert python path to Windows style before writing into bazel.rc
-  if is_cygwin():
+  if is_windows() or is_cygwin():
     python_lib_path = cygpath(python_lib_path)
 
   # Set-up env variables used by python_configure.bzl
@@ -600,7 +600,7 @@
 
     # Find out where the CUDA toolkit is installed
     default_cuda_path = _DEFAULT_CUDA_PATH
-    if is_cygwin():
+    if is_windows() or is_cygwin():
       default_cuda_path = cygpath(
           environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN))
     elif is_linux():
@@ -660,7 +660,7 @@
     # unusable. Going through one more level of expansion to handle that.
     cudnn_install_path = os.path.realpath(
         os.path.expanduser(cudnn_install_path))
-    if is_cygwin():
+    if is_windows() or is_cygwin():
       cudnn_install_path = cygpath(cudnn_install_path)
 
     if is_windows():
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 6fc73c3..ccb58e7 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -95,6 +95,14 @@
 }
 REGISTER_GRADIENT_OP("Selu", SeluGradHelper);
 
+Status L2LossGrad(const Scope& scope, const Operation& op,
+                  const std::vector<Output>& grad_inputs,
+                  std::vector<Output>* grad_outputs) {
+  grad_outputs->push_back(Mul(scope, op.input(0), grad_inputs[0]));
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("L2Loss", L2LossGrad);
+
 Status BiasAddGradHelper(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index f9a512f..affc1e1 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -122,6 +122,14 @@
   RunTest(x, x_init_value, y, shape);
 }
 
+TEST_F(NNGradTest, L2LossGrad) {
+  TensorShape x_shape({5, 2});
+  TensorShape y_shape({1});
+  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+  auto y = L2Loss(scope_, x);
+  RunTest(x, x_shape, y, y_shape);
+}
+
 TEST_F(NNGradTest, BiasAddGradHelper) {
   TensorShape shape({4, 5});
   TensorShape bias_shape({5});
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 2a999f5..2e7765c 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -389,7 +389,7 @@
 
   // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
   // again after the standard optimization passes [http://b/13329423].
-  // TODO(jingyue): SROA may further expose more optimization opportunities, such
+  // TODO(jingyue): SROA may further expose more optimization opportunities such
   // as more precise alias analysis and more function inlining (SROA may change
   // the inlining cost of a function). For now, running SROA already emits good
   // enough code for the evaluated benchmarks. We may want to run more
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 6507a9a..15850bf 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -82,6 +82,7 @@
   set_tests_properties(${_AT_TARGET}
     PROPERTIES ENVIRONMENT "TEST_TMPDIR=${tempdir};TEST_SRCDIR=${testdir}"
   )
+  set_tests_properties(${_AT_TARGET} PROPERTIES TIMEOUT "600")
 
   foreach(datafile ${_AT_DATA})
     file(RELATIVE_PATH datafile_rel ${tensorflow_source_dir} ${datafile})
@@ -117,6 +118,7 @@
     if (_AT_DEPENDS)
       add_dependencies(${_AT_TARGET} ${_AT_DEPENDS})
     endif()
+    set_tests_properties(${sourcefile} PROPERTIES TIMEOUT "600")
   endforeach()
 endfunction(AddPythonTests)
 
diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD
index 645e364..bebcf07 100644
--- a/tensorflow/contrib/gdr/BUILD
+++ b/tensorflow/contrib/gdr/BUILD
@@ -62,6 +62,7 @@
     }),
     deps = [
         ":gdr_proto_cc",
+        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:lib",
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.h b/tensorflow/contrib/gdr/gdr_memory_manager.h
index 7e9fe01..e0e2a3f 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.h
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.h
@@ -16,14 +16,9 @@
 #ifndef GDR_MEMORY_MANAGER_H_
 #define GDR_MEMORY_MANAGER_H_
 
+#include "google/protobuf/any.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 
-namespace google {
-namespace protobuf {
-class Any;
-}
-}
-
 namespace tensorflow {
 
 class Device;
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index 674f5db..ea8d9e0c 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -115,6 +115,7 @@
                     'legacy_linear',
                     'legacy_relu',
                     'OPTIMIZER_CLS_NAMES',
+                    'OPTIMIZER_SUMMARIES',
                     'regression_target',
                     'SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY',
                     'summaries']
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index ac217f0..7eb410b 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -129,8 +129,9 @@
                `None` to use all trainable variables.
     name: The name for this operation is used to scope operations and summaries.
     summaries: List of internal quantities to visualize on tensorboard. If not
-               set only the loss and the learning rate will be reported. The
-               complete list is in OPTIMIZER_SUMMARIES.
+               set, the loss, the learning rate, and the global norm of the
+               gradients will be reported. The complete list of possible values
+               is in OPTIMIZER_SUMMARIES.
     colocate_gradients_with_ops: If True, try colocating gradients with the
                                  corresponding op.
     increment_global_step: Whether to increment `global_step`. If your model
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
index 48d79ec..4c50d40 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py
@@ -28,7 +28,6 @@
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import tf_logging as logging
 
@@ -44,7 +43,7 @@
   x_is_dict, y_is_dict = isinstance(
       x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
   if y_is_dict and n_classes is not None:
-    assert (isinstance(n_classes, dict))
+    assert isinstance(n_classes, dict)
 
   if batch_size is None:
     batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0]
@@ -322,10 +321,12 @@
 
     self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items())
                    ]) if x_is_dict else check_array(x, x.dtype)
-    self._y = None if y is None else \
-      dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if x_is_dict else check_array(y, y.dtype)
+    self._y = None if y is None else (
+        dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())])
+        if y_is_dict else check_array(y, y.dtype))
 
-    # self.n_classes is not None means we're converting raw target indices to one-hot.
+    # self.n_classes is not None means we're converting raw target indices
+    # to one-hot.
     if n_classes is not None:
       if not y_is_dict:
         y_dtype = (np.int64
@@ -344,12 +345,15 @@
         x_shape, y_shape, n_classes, batch_size)
 
     # Input dtype matches dtype of x.
-    self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \
-      else _check_dtype(self._x.dtype)
+    self._input_dtype = (
+        dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())])
+        if x_is_dict else _check_dtype(self._x.dtype))
 
-    # note: self._output_dtype = np.float32 when y is None
-    self._output_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) if y_is_dict \
-      else _check_dtype(self._y.dtype) if y is not None else np.float32
+    # self._output_dtype == np.float32 when y is None
+    self._output_dtype = (
+        dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())])
+        if y_is_dict else (
+            _check_dtype(self._y.dtype) if y is not None else np.float32))
 
     # self.n_classes is None means we're passing in raw target indices
     if n_classes is not None and y_is_dict:
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index 98af47d..30897bb 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -316,14 +316,14 @@
 	IPHONESIMULATOR_SYSROOT := $(shell xcrun --sdk iphonesimulator \
 	--show-sdk-path)
 	IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version)
-	MIN_SDK_VERSION := 8.0
+	MIN_SDK_VERSION := 9.0
 # Override IOS_ARCH with ARMV7, ARMV7S, ARM64, or I386.
 	IOS_ARCH := X86_64
 	ifeq ($(IOS_ARCH),ARMV7)
 		CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
 		-arch armv7 \
 		-fembed-bitcode \
-		-D__thread= \
+		-D__thread=thread_local \
 		-DUSE_GEMM_FOR_CONV \
 		-Wno-c++11-narrowing \
 		-mno-thumb \
@@ -347,7 +347,7 @@
 		CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
 		-arch armv7s \
 		-fembed-bitcode \
-		-D__thread= \
+		-D__thread=thread_local \
 		-DUSE_GEMM_FOR_CONV \
 		-Wno-c++11-narrowing \
 		-mno-thumb \
@@ -371,7 +371,7 @@
 		CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
 		-arch arm64 \
 		-fembed-bitcode \
-		-D__thread= \
+		-D__thread=thread_local \
 		-DUSE_GEMM_FOR_CONV \
 		-Wno-c++11-narrowing \
 		-DTF_LEAN_BINARY \
@@ -395,7 +395,7 @@
 		-arch i386 \
 		-mno-sse \
 		-fembed-bitcode \
-		-D__thread= \
+		-D__thread=thread_local \
 		-DUSE_GEMM_FOR_CONV \
 		-Wno-c++11-narrowing \
 		-DTF_LEAN_BINARY \
@@ -418,7 +418,7 @@
 		CXXFLAGS += -mios-simulator-version-min=$(MIN_SDK_VERSION) \
 		-arch x86_64 \
 		-fembed-bitcode \
-		-D__thread= \
+		-D__thread=thread_local \
 		-DUSE_GEMM_FOR_CONV \
 		-Wno-c++11-narrowing \
 		-DTF_LEAN_BINARY \
diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md
index 835d684..715eb51 100644
--- a/tensorflow/contrib/makefile/README.md
+++ b/tensorflow/contrib/makefile/README.md
@@ -201,7 +201,8 @@
 
 Then, you will need to compile the nsync library for iOS:
 
-```export HOST_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh`
+```bash
+export HOST_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh`
 export TARGET_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh -t ios`
 ```
 
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index cb2fde7..f91e377 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -35,14 +35,18 @@
 
 namespace tensorflow {
 
+static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
+  rendez->StartAbort(s);
+  rendez->Unref();
+}
+
 BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
     : worker_env_(worker_env) {}
 
 BaseRendezvousMgr::~BaseRendezvousMgr() {
   for (auto& p : table_) {
-    BaseRemoteRendezvous* rendez = p.second;
-    rendez->StartAbort(errors::Aborted("Shutdown"));
-    rendez->Unref();
+    auto rendez = p.second;
+    StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
   }
 }
 
@@ -52,7 +56,7 @@
 
 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
   mutex_lock l(mu_);
-  Table::iterator iter = table_.find(step_id);
+  auto iter = table_.find(step_id);
   if (iter == table_.end()) {
     auto rr = Create(step_id, worker_env_);
     iter = table_.insert({step_id, rr}).first;
@@ -64,7 +68,7 @@
 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
                                        const Rendezvous::ParsedKey& parsed,
                                        Rendezvous::DoneCallback done) {
-  BaseRemoteRendezvous* rendez = FindOrCreate(step_id);
+  auto rendez = FindOrCreate(step_id);
   using namespace std::placeholders;
   Rendezvous::DoneCallback done_cb = std::bind(
       [rendez](Rendezvous::DoneCallback done,
@@ -101,15 +105,15 @@
   Rendezvous* rendez = nullptr;
   {
     mutex_lock l(mu_);
-    Table::iterator iter = table_.find(step_id);
+    auto iter = table_.find(step_id);
     if (iter != table_.end()) {
       rendez = iter->second;
       table_.erase(iter);
     }
   }
-  if (!rendez) return;
-  rendez->StartAbort(errors::Aborted("Cleanup ", step_id));
-  rendez->Unref();
+  if (rendez) {
+    StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
+  }
 }
 
 void BaseRendezvousMgr::CleanupAll() {
@@ -122,8 +126,7 @@
     table_.clear();
   }
   for (auto rendez : rendezs) {
-    rendez->StartAbort(errors::Aborted("Shutdown"));
-    rendez->Unref();
+    StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
   }
 }
 
@@ -165,7 +168,7 @@
     session_ = session;
     std::swap(deferred_calls, deferred_calls_);
   }
-  for (DeferredCall& call : deferred_calls) {
+  for (auto& call : deferred_calls) {
     RecvLocalAsyncInternal(call.parsed, std::move(call.done));
   }
   return Status::OK();
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 46c21dc..25b35a6 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -310,7 +310,7 @@
   FunctionLibraryRuntime* function_library() const { return flib_; }
 
   // The GraphDef version whose behavior we should follow.
-  const int graph_def_version() const { return graph_def_version_; }
+  int graph_def_version() const { return graph_def_version_; }
 
   // Helper routines for the OP_REQUIRES macros
   void CtxFailure(Status s);
diff --git a/tensorflow/core/profiler/g3doc/advise.md b/tensorflow/core/profiler/g3doc/advise.md
index d87b0d8..d0de831 100644
--- a/tensorflow/core/profiler/g3doc/advise.md
+++ b/tensorflow/core/profiler/g3doc/advise.md
@@ -86,7 +86,7 @@
 *   Checks RecvTensor RPC latency and bandwidth.
 *   Checks CPU/Memory utilization of the job.
 
-####AcceleratorUtilization Checker
+#### AcceleratorUtilization Checker
 * Checks what percentage of time the accelerator spends on computation.
 
 #### OperationChecker
@@ -100,7 +100,7 @@
 *   Checks the most expensive graph nodes.
 *   Checks the most expensive graph-building Python codes.
 
-####Contribute Your Checker
+#### Contribute Your Checker
 
 Follow examples of accelerator_utilization_checker.h
 
diff --git a/tensorflow/core/profiler/g3doc/command_line.md b/tensorflow/core/profiler/g3doc/command_line.md
index 857b5e6..e2839a6 100644
--- a/tensorflow/core/profiler/g3doc/command_line.md
+++ b/tensorflow/core/profiler/g3doc/command_line.md
@@ -51,7 +51,7 @@
 Note: this feature is not well maintained now.
 
 
-###Start `tfprof`
+### Start `tfprof`
 
 #### Build `tfprof`
 
@@ -140,9 +140,9 @@
 -output
 ```
 
-###Examples
+### Examples
 
-####Profile Python Time
+#### Profile Python Time
 ```shell
 # Requires --graph_path --op_log_path
 tfprof> code -max_depth 1000 -show_name_regexes .*model_analyzer.*py.* -select micros -account_type_regexes .* -order_by micros
diff --git a/tensorflow/core/profiler/g3doc/options.md b/tensorflow/core/profiler/g3doc/options.md
index 15712d0..ddee63a 100644
--- a/tensorflow/core/profiler/g3doc/options.md
+++ b/tensorflow/core/profiler/g3doc/options.md
@@ -1,6 +1,6 @@
-##Options
+## Options
 
-###Overview
+### Overview
 
 For all tfprof views, the profiles are processed with the following procedures
 
@@ -35,7 +35,7 @@
 4) Finally, the filtered data structure is output in a format depending
    on the `-output` option.
 
-####Option Semantics In Different View
+#### Option Semantics In Different View
 options usually have the same semantics in different views. However, some
 can vary. For example `-max_depth` in scope view means the depth of
 name scope <b>tree</b>. In op view, it means the length of operation <b>list</b>.
@@ -68,7 +68,7 @@
               by the current operation. For example, it can be a tensor
               forwarded from input to output, with in-place mutation.
 
-###Docs
+### Docs
 
 `-max_depth`: Show nodes that are at most this number of hops from starting node in the data structure.
 
diff --git a/tensorflow/core/profiler/g3doc/profile_memory.md b/tensorflow/core/profiler/g3doc/profile_memory.md
index a00683d..6eda5ab 100644
--- a/tensorflow/core/profiler/g3doc/profile_memory.md
+++ b/tensorflow/core/profiler/g3doc/profile_memory.md
@@ -1,4 +1,4 @@
-##Profile Memory
+## Profile Memory
 
 It is generally a good idea to visualize the memory usage in timeline.
 It allows you to see the memory consumption of each GPU over time.
diff --git a/tensorflow/core/profiler/g3doc/profile_model_architecture.md b/tensorflow/core/profiler/g3doc/profile_model_architecture.md
index a42b2e9..61bb66b 100644
--- a/tensorflow/core/profiler/g3doc/profile_model_architecture.md
+++ b/tensorflow/core/profiler/g3doc/profile_model_architecture.md
@@ -1,9 +1,9 @@
-##Profile Model Architecture
+## Profile Model Architecture
 
 * [Profile Model Parameters](#profile-model-parameters)
 * [Profile Model Float Operations](#profile-model-float-operations)
 
-###Profile Model Parameters
+### Profile Model Parameters
 
 <b>Notes:</b>
 `VariableV2` operation type might contain variables created by TensorFlow
@@ -39,9 +39,9 @@
 sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
 ```
 
-###Profile Model Float Operations
+### Profile Model Float Operations
 
-####Caveats
+#### Caveats
 
 For an operation to have float operation statistics:
 
diff --git a/tensorflow/core/profiler/g3doc/profile_time.md b/tensorflow/core/profiler/g3doc/profile_time.md
index e11a755..4aafc69 100644
--- a/tensorflow/core/profiler/g3doc/profile_time.md
+++ b/tensorflow/core/profiler/g3doc/profile_time.md
@@ -1,4 +1,4 @@
-##Profile Time
+## Profile Time
 
 * [Times in TensorFlow and tfprof](#times-in-tensorflow-and-tfprof)
 * [Profile by Python Code](#profile-by-python-code)
@@ -7,7 +7,7 @@
 * [Profile by Name Scope](#profile-by-name-scope)
 
 
-###Times in TensorFlow and tfprof
+### Times in TensorFlow and tfprof
 When we run a model, Tensorflow schedules and runs the nodes (operations)
 in the graph. An operation can be placed on an accelerator or on CPU.
 
@@ -37,7 +37,7 @@
 should be 0.
 
 
-###Profile by Python Code
+### Profile by Python Code
 ```python
 # In code view, the time of each line of Python code is the aggregated
 # times of all operations created by that line.
@@ -112,7 +112,7 @@
 </left>
 
 
-###Profile by Operation Type
+### Profile by Operation Type
 ```python
 # In op view, you can view the aggregated time of each operation type.
 tfprof> op -select micros,occurrence -order_by micros
@@ -138,7 +138,7 @@
 ```
 
 
-###Profile by Graph
+### Profile by Graph
 
 Usually, use graph view to generate a timeline to visualize the result.
 
@@ -163,7 +163,7 @@
 ******************************************************
 ```
 
-###Profile by Name Scope
+### Profile by Name Scope
 
 Usually scope view allows you to pin point the problematic places if you
 have properly named your operations with tf.name_scope or tf.variable_scope.
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 43e0990..d5e4815 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -151,10 +151,10 @@
      (tensorflow)$ <b>pip install --upgrade tensorflow-gpu</b>  # for Python 2.7 and GPU
      (tensorflow)$ <b>pip3 install --upgrade tensorflow-gpu</b> # for Python 3.n and GPU</pre>
 
-     If the preceding command succeeds, skip Step 5. If the preceding
-     command fails, perform Step 5.
+     If the preceding command succeeds, skip Step 6. If the preceding
+     command fails, perform Step 6.
 
-  5. (Optional) If Step 4 failed (typically because you invoked a pip version
+  6. (Optional) If Step 5 failed (typically because you invoked a pip version
      lower than 8.1), install TensorFlow in the active virtualenv environment
      by issuing a command of the following format:
 
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index be6a490..3025c99 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -71,12 +71,14 @@
 
 ## Installing with native pip
 
-If the following version of Python is not installed on your machine,
+If one of the following versions of Python is not installed on your machine,
 install it now:
 
   * [Python 3.5.x 64-bit from python.org](https://www.python.org/downloads/release/python-352/)
+  * [Python 3.6.x 64-bit from python.org](https://www.python.org/downloads/release/python-362/)
 
-Note that Python 3.5.x comes with the pip3 package manager, which is the
+TensorFlow supports Python 3.5.x and 3.6.x on Windows.
+Note that Python 3 comes with the pip3 package manager, which is the
 program you'll use to install TensorFlow.
 
 To install TensorFlow, start a terminal. Then issue the appropriate
diff --git a/tensorflow/examples/speech_commands/README.md b/tensorflow/examples/speech_commands/README.md
index 3b78210..63be04e 100644
--- a/tensorflow/examples/speech_commands/README.md
+++ b/tensorflow/examples/speech_commands/README.md
@@ -1,4 +1,4 @@
 # Speech Commands Example
 
 This is a basic speech recognition example. For more information, see the
-tutorial at http://tensorflow.org/tutorials/audio_recognition.
+tutorial at https://www.tensorflow.org/versions/master/tutorials/audio_recognition.
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 44ab1a6..a8434d0 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -2473,7 +2473,7 @@
       weighted_column = sparse_ops.sparse_merge(
           sp_ids=id_tensor,
           sp_values=weight_tensor,
-          vocab_size=self._variable_shape[-1])
+          vocab_size=int(self._variable_shape[-1]))
       return sparse_ops.sparse_tensor_to_dense(weighted_column)
 
     dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index b14ec73..3057776 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -3206,6 +3206,20 @@
     with _initialized_session():
       self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval())
 
+  def test_transform_with_weighted_column(self):
+    # Github issue 12557
+    ids = fc.categorical_column_with_vocabulary_list(
+        key='ids', vocabulary_list=('a', 'b', 'c'))
+    weights = fc.weighted_categorical_column(ids, 'weights')
+    indicator = fc.indicator_column(weights)
+    features = {
+        'ids': constant_op.constant(['c', 'b', 'a'], shape=(1, 3)),
+        'weights': constant_op.constant([2., 4., 6.], shape=(1, 3))
+    }
+    indicator_tensor = _transform_features(features, [indicator])[indicator]
+    with _initialized_session():
+      self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
+
   def test_linear_model(self):
     animal = fc.indicator_column(
         fc.categorical_column_with_identity('animal', num_buckets=4))
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index 8e31554..f2b35bc 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -82,7 +82,7 @@
 
   // Returns the limit on the number of simultaneously resident blocks
   // on a multiprocessor.
-  const uint64 blocks_per_core_limit() const { return blocks_per_core_limit_; }
+  uint64 blocks_per_core_limit() const { return blocks_per_core_limit_; }
 
   // Returns the limit on the total number of threads that can be launched in a
   // single block; i.e. the limit on x * y * z dimensions of a ThreadDim.
@@ -141,7 +141,7 @@
   uint64 device_memory_size() const { return device_memory_size_; }
 
   // Returns the device's core clock rate in GHz.
-  const float clock_rate_ghz() const { return clock_rate_ghz_; }
+  float clock_rate_ghz() const { return clock_rate_ghz_; }
 
   // Returns whether ECC is enabled.
   bool ecc_enabled() const { return ecc_enabled_; }
diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h
index d9d40d7..8ef091f 100644
--- a/tensorflow/stream_executor/kernel.h
+++ b/tensorflow/stream_executor/kernel.h
@@ -302,7 +302,7 @@
   //
   // Returns a default-constructed KernelArg if there is no next argument.
   KernelArg next() {
-    KernelArg result;
+    KernelArg result = {};
     if (!has_next()) {
       return result;
     } else if ((shmem_indices_iter_ != shmem_indices_end_) &&
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index e525e11..4405678 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -276,8 +276,9 @@
   """Check for given lingering strings."""
   formatted_string = lingering_string.replace(".", r"\.")
   try:
-    linger_strs = subprocess.check_output(
-        ['grep', '-rnoH', formatted_string, TF_SRC_DIR]).split("\n")
+    linger_str_output = subprocess.check_output(
+        ["grep", "-rnoH", formatted_string, TF_SRC_DIR])
+    linger_strs = linger_str_output.decode("utf8").split("\n")
   except subprocess.CalledProcessError:
     linger_strs = []
 
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index bb63349..d623169 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -84,6 +84,7 @@
         "//tensorflow/python/saved_model",
         "//tensorflow/python:spectral_ops_test_util",
         "//tensorflow/python/tools:tools_pip",
+        "//tensorflow/python/eager:eager_pip",
         # These targets don't build on Windows yet. Exclude them for now.
         # "//tensorflow/contrib/ndlstm",
         # "//tensorflow/contrib/slim",
diff --git a/third_party/sqlite.BUILD b/third_party/sqlite.BUILD
index f593b71..9840d7b 100644
--- a/third_party/sqlite.BUILD
+++ b/third_party/sqlite.BUILD
@@ -2,9 +2,9 @@
 #   Sqlite3 library. Provides utilities for interacting
 #   with sqlite3 databases.
 
-licenses(["notice"])  # BSD/MIT-like license
+licenses(["unencumbered"])  # Public Domain
 
-exports_files(["LICENSE"])
+# exports_files(["LICENSE"])
 
 cc_library(
     name = "sqlite",