Merge pull request #39399 from marload:patch-0511-01

PiperOrigin-RevId: 311208506
Change-Id: Ic4d22e1b01554cfe7eaab759ecbc0c2cadc3c002
diff --git a/.bazelrc b/.bazelrc
index cf15d09..224238d 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -163,6 +163,8 @@
 build:dbg --config=opt -c dbg
 # for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
 build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
+# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
+build:dbg --copt -DDEBUG_BUILD
 
 build:tensorrt --action_env TF_NEED_TENSORRT=1
 
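The new define gives source code a way to detect --config=dbg builds. A minimal sketch of how a translation unit might branch on it (the helper function is hypothetical; only the DEBUG_BUILD macro comes from the flag above):

    #include <cstdio>

    // Hypothetical helper: reports whether the binary was built with
    // --config=dbg, which now defines DEBUG_BUILD (see .bazelrc above).
    void ReportBuildMode() {
    #ifdef DEBUG_BUILD
      std::printf("debug build\n");
    #else
      std::printf("release build\n");
    #endif
    }
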
diff --git a/configure.py b/configure.py
index ca2ff59..945c303 100644
--- a/configure.py
+++ b/configure.py
@@ -217,7 +217,8 @@
     elif not os.path.exists(python_bin_path):
       print('Invalid python path: {} cannot be found.'.format(python_bin_path))
     else:
-      print('{} is not executable.  Is it the python binary?'.format(python_bin_path))
+      print('{} is not executable.  Is it the python binary?'.format(
+          python_bin_path))
     environ_cp['PYTHON_BIN_PATH'] = ''
 
   # Convert python path to Windows style before checking lib and version
@@ -320,7 +321,8 @@
       Raise the error to avoid infinitely looping.
   """
   if not question:
-    question = 'Do you wish to build TensorFlow with {} support?'.format(query_item)
+    question = 'Do you wish to build TensorFlow with {} support?'.format(
+        query_item)
   if not yes_reply:
     yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
   if not no_reply:
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index bf3af3c..ab4316d 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -530,6 +530,13 @@
 # TODO(b/154762408) Remove this package group once it's no longer needed.
 package_group(name = "composite_tensor_whitelist")
 
+# Packages that use private `types` symbols, until they are exported.
+# TODO(b/154650521) Remove.
+package_group(
+    name = "types_whitelist",
+    packages = ["//learning/deepmind/tensorflow/replicator/..."],
+)
+
 filegroup(
     name = "intel_binary_blob",
     data = if_mkl_ml(
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 73c2f78..5c01ccb 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -924,7 +924,7 @@
       context->GetDevicePlacementPolicy());
 }
 
-TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
+TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
   tensorflow::Tensor tensor;
   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
   if (!status->status.ok()) return nullptr;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 070b3a9..5afe304 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -137,7 +137,7 @@
 // placed in memory of different devices or remote address spaces.
 typedef struct TFE_TensorHandle TFE_TensorHandle;
 
-TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
                                                             TF_Status* status);
 // Indicates that the caller will not be using `h` any more.
 TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
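With TFE_NewTensorHandle now taking a const TF_Tensor*, callers holding const tensors no longer need a cast. A minimal sketch against the C API (the helper name is illustrative; error handling is reduced to the final status check):

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/eager/c_api.h"

    // Builds a scalar float tensor and wraps it in an eager handle. The
    // const view is now accepted directly by TFE_NewTensorHandle.
    TFE_TensorHandle* MakeScalarHandle(TF_Status* status) {
      TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, /*dims=*/nullptr,
                                       /*num_dims=*/0, sizeof(float));
      *static_cast<float*>(TF_TensorData(t)) = 42.0f;
      const TF_Tensor* const_view = t;
      TFE_TensorHandle* handle = TFE_NewTensorHandle(const_view, status);
      TF_DeleteTensor(t);  // the handle owns its own copy of the data
      return TF_GetCode(status) == TF_OK ? handle : nullptr;
    }
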
diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc
index 8c9aa97..bd99189 100644
--- a/tensorflow/c/eager/c_api_unified_experimental_test.cc
+++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc
@@ -29,7 +29,7 @@
 namespace tensorflow {
 namespace {
 
-TEST(UnifedCAPI, TestBasicEager) {
+TEST(UnifiedCAPI, TestBasicEager) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -81,7 +81,7 @@
   TF_DeleteExecutionContext(ctx);
 }
 
-TEST(UnifedCAPI, TestBasicGraph) {
+TEST(UnifiedCAPI, TestBasicGraph) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -185,7 +185,7 @@
   TF_DeleteExecutionContext(eager_execution_ctx);
 }
 
-TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
+TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -201,7 +201,7 @@
   TF_DeleteExecutionContext(ctx);
 }
 
-TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
+TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -222,7 +222,7 @@
   TF_DeleteExecutionContext(graph_ctx);
 }
 
-TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
+TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -243,7 +243,7 @@
   TF_DeleteExecutionContext(graph_ctx);
 }
 
-TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
+TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
   // Build an Eager context.
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
@@ -289,7 +289,7 @@
   TF_DeleteExecutionContext(graph_ctx);
 }
 
-TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
+TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
index f4dbcc6..3b2640e 100644
--- a/tensorflow/c/eager/parallel_device/BUILD
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -27,6 +27,7 @@
     name = "parallel_device",
     srcs = [":sources"],
     hdrs = [":headers"],
+    visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/c:c_api",
         "//tensorflow/c/eager:c_api",
@@ -43,6 +44,7 @@
     srcs = ["parallel_device_test.cc"],
     deps = [
         ":parallel_device",
+        ":parallel_device_ops",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_experimental",
         "//tensorflow/c/eager:c_api",
@@ -52,3 +54,19 @@
         "//tensorflow/core:test_main",
     ],
 )
+
+# Note: ParallelDevice-specific ops are experimental and not currently linked
+# into TensorFlow by default; they are only used in a few tests.
+filegroup(
+    name = "parallel_device_ops_srcs",
+    srcs = ["parallel_device_ops.cc"],
+    visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
+)
+
+cc_library(
+    name = "parallel_device_ops",
+    srcs = [":parallel_device_ops_srcs"],
+    visibility = ["//tensorflow:internal"],
+    deps = ["//tensorflow/core:framework"],
+    alwayslink = 1,
+)
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc
index e684680..27c2699 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -92,6 +92,10 @@
                                                        TFE_TensorHandle* tensor,
                                                        TF_Status* status) const;
 
+  // A parallel tensor with scalar integers numbering component devices.
+  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
+                                            TF_Status* status) const;
+
   // Takes a description of a single operation being executed on the
   // ParallelDevice, and in turn runs one operation per component device with
   // its corresponding inputs from the input ParallelTensors (or
@@ -208,6 +212,46 @@
                                            status);
 }
 
+std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
+    TFE_Context* context, TF_Status* status) const {
+  // TODO(allenl): We could cache DeviceIDs (keyed by context).
+  std::vector<TensorHandlePtr> components;
+  components.reserve(underlying_devices_.size());
+  for (int device_index = 0; device_index < underlying_devices_.size();
+       ++device_index) {
+    int64_t* device_id = new int64_t;
+    *device_id = device_index;
+    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+        TF_NewTensor(
+            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
+            sizeof(int64_t),
+            [](void* data, size_t, void* arg) {
+              delete reinterpret_cast<int64_t*>(data);
+            },
+            nullptr),
+        TF_DeleteTensor);
+    // TODO(allenl): Here and when executing regular operations, we could hold
+    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
+    // device names repeatedly.
+    OpPtr const_op(TFE_NewOp(context, "Const", status));
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
+                    status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
+    TFE_TensorHandle* device_handle;
+    int num_outputs = 1;
+    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    components.emplace_back(device_handle);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+  }
+  return ParallelTensor::FromTensorHandles(*this, std::move(components),
+                                           status);
+}
+
 absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
     TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
     const char* operation_name, const TFE_OpAttrs* attributes,
@@ -282,6 +326,13 @@
     }
     result.emplace(std::move(outputs));
     return result;
+  } else if (operation_name == std::string("DeviceID")) {
+    std::vector<MaybeParallelTensorOwned> result_content;
+    result_content.reserve(1);
+    result_content.push_back(DeviceIDs(context, status));
+    if (TF_GetCode(status) != TF_OK) return result;
+    result.emplace(std::move(result_content));
+    return result;
   }
   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
       maybe_parallel_results(
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc
new file mode 100644
index 0000000..1decffc
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+// TODO(allenl): Figure out if we need this op, and if so whether we should move
+// it to core TF. Right now the eager C API does some checking of op
+// registrations before calling into custom devices, but we may be able to avoid
+// that.
+REGISTER_OP("DeviceID")
+    .Output("device_id: int64")
+    .SetIsStateful()
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape);
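Once registered, DeviceID takes no inputs and produces a scalar int64; on the parallel device, the custom-device Execute path above intercepts the op name and returns one scalar per component device. A condensed usage sketch via the eager C API, assuming device_name names a registered parallel device (status checks elided):

    TFE_Op* op = TFE_NewOp(context, "DeviceID", status);
    TFE_OpSetDevice(op, device_name, status);
    TFE_TensorHandle* result;
    int num_retvals = 1;
    TFE_Execute(op, &result, &num_retvals, status);
    // `result` is a parallel tensor; unpacking it (ExtractPerDeviceValues in
    // the test below) yields the scalars 0, 1, ... on the component devices.
    TFE_DeleteOp(op);
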
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
index 9b0613b..fdc1404 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -278,14 +278,15 @@
 }
 
 // Assert that `handle` is equal to `expected_value`.
-void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
+template <typename value_type>
+void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
       TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  ASSERT_EQ(expected_value,
-            *static_cast<float*>(TF_TensorData(value_zero.get())));
+  EXPECT_EQ(expected_value,
+            *static_cast<value_type*>(TF_TensorData(value_zero.get())));
 }
 
 template <std::size_t num_devices>
@@ -343,8 +344,8 @@
     ExtractPerDeviceValues(context, read.get(), &components, status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
-    AssertScalarFloatEq(components[0].get(), 20.);
-    AssertScalarFloatEq(components[1].get(), 20.);
+    ExpectScalarEq<float>(components[0].get(), 20.);
+    ExpectScalarEq<float>(components[1].get(), 20.);
 
     std::string first_device =
         TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@@ -373,8 +374,8 @@
     ExtractPerDeviceValues(context, read.get(), &components, status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
-    AssertScalarFloatEq(components[0].get(), 23.);
-    AssertScalarFloatEq(components[1].get(), 18.);
+    ExpectScalarEq<float>(components[0].get(), 23.);
+    ExpectScalarEq<float>(components[1].get(), 18.);
 
     std::string first_device =
         TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@@ -383,6 +384,32 @@
         TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
     ASSERT_EQ(underlying_devices[1], second_device);
   }
+  // Compute the device IDs twice and verify the results.
+  for (int i = 0; i < 2; ++i) {
+    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+        TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_OpSetDevice(op.get(), device_name, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_TensorHandle* result_handle;
+    int num_retvals = 1;
+    TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    std::array<TensorHandlePtr, 2> components;
+    ExtractPerDeviceValues(context, result_handle, &components, status.get());
+    TFE_DeleteTensorHandle(result_handle);
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    ExpectScalarEq<int64_t>(components[0].get(), 0);
+    ExpectScalarEq<int64_t>(components[1].get(), 1);
+    std::string first_device =
+        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+    ASSERT_EQ(underlying_devices[0], first_device);
+    std::string second_device =
+        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+    ASSERT_EQ(underlying_devices[1], second_device);
+  }
 }
 
 TEST(PARALLEL_DEVICE, TestBasicCPU) {
@@ -498,8 +525,8 @@
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
     // The value of the original tensor is replicated on each device.
-    AssertScalarFloatEq(components[0].get(), 3.);
-    AssertScalarFloatEq(components[1].get(), 3.);
+    ExpectScalarEq<float>(components[0].get(), 3.);
+    ExpectScalarEq<float>(components[1].get(), 3.);
 
     // Verify that the mirrors are placed on the component devices.
     std::string first_device =
@@ -630,7 +657,7 @@
                          &second_components, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
-  AssertScalarFloatEq(second_components[1].get(), 9.);
+  ExpectScalarEq<float>(second_components[1].get(), 9.);
 
   // Verify that the mirrors are placed on the component devices.
   std::string first_device = TFE_TensorHandleBackingDeviceName(
@@ -644,8 +671,8 @@
   std::array<TensorHandlePtr, 2> first_components;
   ExtractPerDeviceValues(context.get(), second_components[0].get(),
                          &first_components, status.get());
-  AssertScalarFloatEq(first_components[0].get(), 3.);
-  AssertScalarFloatEq(first_components[1].get(), 6.);
+  ExpectScalarEq<float>(first_components[0].get(), 3.);
+  ExpectScalarEq<float>(first_components[1].get(), 6.);
 
   first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
                                                    status.get());
@@ -806,8 +833,8 @@
   ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                          status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  AssertScalarFloatEq(result_components[0].get(), 3.);
-  AssertScalarFloatEq(result_components[1].get(), 3.);
+  ExpectScalarEq<float>(result_components[0].get(), 3.);
+  ExpectScalarEq<float>(result_components[1].get(), 3.);
 }
 
 void RegisterCollectiveMulFunction(TFE_Context* context,
@@ -909,8 +936,8 @@
   ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                          status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
-  AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
+  ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
+  ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
 
   std::string first_device = TFE_TensorHandleBackingDeviceName(
       result_components[0].get(), status.get());
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index e8cb40f..e1fad8e 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -178,7 +178,7 @@
     name = "ops",
     srcs = ["framework/ops.cc"],
     hdrs = ["framework/ops.h"],
-    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+    android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
     deps = [
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
@@ -197,7 +197,7 @@
         "framework/scope_internal.h",
     ],
     hdrs = ["framework/scope.h"],
-    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+    android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
     common_deps = [
         ":ops",
     ],
@@ -237,7 +237,7 @@
     name = "client_session",
     srcs = ["client/client_session.cc"],
     hdrs = ["client/client_session.h"],
-    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+    android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
     common_deps = [
         ":ops",
         ":scope",
@@ -275,7 +275,7 @@
     srcs = ["ops/const_op.cc"],
     hdrs = ["ops/const_op.h"],
     android_deps = [
-        "//tensorflow/core:android_tensorflow_lib",
+        "//tensorflow/core:portable_tensorflow_lib",
     ],
     common_deps = [
         ":ops",
@@ -304,7 +304,7 @@
     srcs = ["ops/while_loop.cc"],
     hdrs = ["ops/while_loop.h"],
     android_deps = [
-        "//tensorflow/core:android_tensorflow_lib",
+        "//tensorflow/core:portable_tensorflow_lib",
     ],
     common_deps = [
         ":cc_ops",
diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD
index 93acf1b..045d4e6 100644
--- a/tensorflow/cc/experimental/base/public/BUILD
+++ b/tensorflow/cc/experimental/base/public/BUILD
@@ -62,3 +62,17 @@
         "//tensorflow/c:tf_tensor",
     ],
 )
+
+cc_library(
+    name = "tensorhandle",
+    hdrs = [
+        "tensorhandle.h",
+    ],
+    deps = [
+        ":runtime",
+        ":status",
+        ":tensor",
+        "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_experimental",
+    ],
+)
diff --git a/tensorflow/cc/experimental/base/public/runtime.h b/tensorflow/cc/experimental/base/public/runtime.h
index 47fd886..711a38c 100644
--- a/tensorflow/cc/experimental/base/public/runtime.h
+++ b/tensorflow/cc/experimental/base/public/runtime.h
@@ -21,6 +21,7 @@
 #include "tensorflow/c/eager/c_api_experimental.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // Runtime represents an opaque instance of a Tensorflow runtime, with its own
@@ -40,6 +41,7 @@
  private:
   friend class RuntimeBuilder;
   friend class SavedModelAPI;
+  friend class TensorHandle;
 
   // Wraps a TFE_Context. Takes ownership of ctx.
   explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
@@ -63,6 +65,7 @@
 };
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
diff --git a/tensorflow/cc/experimental/base/public/runtime_builder.h b/tensorflow/cc/experimental/base/public/runtime_builder.h
index ed3c93a..737e06c 100644
--- a/tensorflow/cc/experimental/base/public/runtime_builder.h
+++ b/tensorflow/cc/experimental/base/public/runtime_builder.h
@@ -24,6 +24,7 @@
 #include "tensorflow/cc/experimental/base/public/status.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
@@ -79,6 +80,7 @@
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
diff --git a/tensorflow/cc/experimental/base/public/status.h b/tensorflow/cc/experimental/base/public/status.h
index f91f2ca..98c8cf6 100644
--- a/tensorflow/cc/experimental/base/public/status.h
+++ b/tensorflow/cc/experimental/base/public/status.h
@@ -22,6 +22,7 @@
 #include "tensorflow/c/tf_status.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // Status is a wrapper around an error code and an optional error message.
@@ -57,6 +58,7 @@
   friend class RuntimeBuilder;
   friend class Runtime;
   friend class SavedModelAPI;
+  friend class TensorHandle;
 
   // Wraps a TF_Status*, and takes ownership of it.
   explicit Status(TF_Status* status) : status_(status) {}
@@ -88,6 +90,7 @@
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
diff --git a/tensorflow/cc/experimental/base/public/tensor.h b/tensorflow/cc/experimental/base/public/tensor.h
index 26b0e5d..fc44726 100644
--- a/tensorflow/cc/experimental/base/public/tensor.h
+++ b/tensorflow/cc/experimental/base/public/tensor.h
@@ -28,6 +28,7 @@
 #include "tensorflow/cc/experimental/base/public/status.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // Tensor represents an n-dimensional array of values.
@@ -168,6 +169,7 @@
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
diff --git a/tensorflow/cc/experimental/base/public/tensorhandle.h b/tensorflow/cc/experimental/base/public/tensorhandle.h
new file mode 100644
index 0000000..99453ee
--- /dev/null
+++ b/tensorflow/cc/experimental/base/public/tensorhandle.h
@@ -0,0 +1,98 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
+#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/cc/experimental/base/public/runtime.h"
+#include "tensorflow/cc/experimental/base/public/status.h"
+#include "tensorflow/cc/experimental/base/public/tensor.h"
+
+namespace tensorflow {
+namespace experimental {
+namespace cc {
+
+// An opaque representation of a tensor computed/managed by the Tensorflow
+// runtime (tensorflow::cc::Runtime). Unlike a tensor, a TensorHandle may refer
+// to tensors placed in memory of different devices or remote address spaces.
+// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
+// from it.
+class TensorHandle {
+ public:
+  // Unwraps a Tensor from the given TensorHandle. If an error occurred,
+  // status->ok() will be false, and the returned Tensor must not be used.
+  Tensor Resolve(Status* status);
+
+  // Constructs a TensorHandle from a Tensor. If an error occurred,
+  // status->ok() will be false, and the returned TensorHandle must not be used.
+  static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
+                                 Status* status);
+
+  // TensorHandle is movable but not copyable
+  TensorHandle(TensorHandle&&) = default;
+  TensorHandle& operator=(TensorHandle&&) = default;
+
+ private:
+  // Wraps a TFE_TensorHandle. Takes ownership of handle.
+  explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
+
+  // TensorHandle is not copyable
+  TensorHandle(const TensorHandle&) = delete;
+  TensorHandle& operator=(const TensorHandle&) = delete;
+
+  // Returns the underlying TFE_TensorHandle that this object wraps.
+  // This object retains ownership of the pointer.
+  TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
+
+  // Deletes the currently wrapped TFE_TensorHandle, replaces it with handle,
+  // and takes ownership of handle.
+  void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
+
+  struct TFETensorHandleDeleter {
+    void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
+  };
+  std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
+};
+
+inline Tensor TensorHandle::Resolve(Status* status) {
+  TF_Tensor* tensor =
+      TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
+  if (!status->ok()) {
+    return Tensor(nullptr);
+  }
+  return Tensor(tensor);
+}
+
+inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
+                                             const Runtime& runtime,
+                                             Status* status) {
+  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
+      runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
+  if (!status->ok()) {
+    return TensorHandle(nullptr);
+  }
+  return TensorHandle(tensor_handle);
+}
+
+}  // namespace cc
+}  // namespace experimental
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
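A minimal usage sketch for the new TensorHandle class, assuming a Runtime built with RuntimeBuilder as in the tests added below (status checks elided):

    using tensorflow::experimental::cc::Runtime;
    using tensorflow::experimental::cc::RuntimeBuilder;
    using tensorflow::experimental::cc::Status;
    using tensorflow::experimental::cc::Tensor;
    using tensorflow::experimental::cc::TensorHandle;

    Status status;
    RuntimeBuilder builder;
    std::unique_ptr<Runtime> runtime = builder.Build(&status);

    float value = 42.0f;
    Tensor tensor = Tensor::FromBuffer(
        /*dtype=*/TF_FLOAT, /*shape=*/{}, /*data=*/&value,
        /*len=*/sizeof(value), /*deleter=*/[](void*, size_t) {}, &status);

    // Wrap the host tensor in a runtime-managed handle, then resolve it back.
    TensorHandle handle = TensorHandle::FromTensor(tensor, *runtime, &status);
    Tensor resolved = handle.Resolve(&status);
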
diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD
index a2b634a..f449d61 100644
--- a/tensorflow/cc/experimental/base/tests/BUILD
+++ b/tensorflow/cc/experimental/base/tests/BUILD
@@ -5,12 +5,22 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
+cc_library(
+    name = "tensor_types_test_util",
+    testonly = True,
+    hdrs = ["tensor_types_test_util.h"],
+    deps = [
+        "//tensorflow/c:tf_datatype",
+    ],
+)
+
 tf_cc_test(
     name = "tensor_test",
     srcs = [
         "tensor_test.cc",
     ],
     deps = [
+        ":tensor_types_test_util",
         "//tensorflow/c:tf_datatype",
         "//tensorflow/cc/experimental/base/public:status",
         "//tensorflow/cc/experimental/base/public:tensor",
@@ -19,3 +29,22 @@
         "//tensorflow/core:test_main",
     ],
 )
+
+tf_cc_test(
+    name = "tensorhandle_test",
+    srcs = [
+        "tensorhandle_test.cc",
+    ],
+    deps = [
+        ":tensor_types_test_util",
+        "//tensorflow/c:tf_datatype",
+        "//tensorflow/cc/experimental/base/public:runtime",
+        "//tensorflow/cc/experimental/base/public:runtime_builder",
+        "//tensorflow/cc/experimental/base/public:status",
+        "//tensorflow/cc/experimental/base/public:tensor",
+        "//tensorflow/cc/experimental/base/public:tensorhandle",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
diff --git a/tensorflow/cc/experimental/base/tests/tensor_test.cc b/tensorflow/cc/experimental/base/tests/tensor_test.cc
index 86a50ba..33f9ab6 100644
--- a/tensorflow/cc/experimental/base/tests/tensor_test.cc
+++ b/tensorflow/cc/experimental/base/tests/tensor_test.cc
@@ -16,69 +16,22 @@
 #include "tensorflow/cc/experimental/base/public/tensor.h"
 
 #include <stddef.h>
-
-#include <cstdint>
+#include <stdint.h>
 
 #include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/test.h"
 
-namespace tensorflow {
 namespace {
 
-// Each of the following struct types have two members: a kDType that
-// corresponds to a TF_Datatype enum value, and a typedef "type"
-// of its corresponding C++ type. These types allow us to write Dtype-agnostic
-// tests via GoogleTest's TypedTests:
-// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
-struct FloatType {
-  using type = float;
-  static constexpr TF_DataType kDType = TF_FLOAT;
-};
+using tensorflow::experimental::cc::Status;
+using tensorflow::experimental::cc::Tensor;
 
-struct DoubleType {
-  using type = double;
-  static constexpr TF_DataType kDType = TF_DOUBLE;
-};
-
-struct Int32Type {
-  using type = int32_t;
-  static constexpr TF_DataType kDType = TF_INT32;
-};
-
-struct UINT8Type {
-  using type = uint8_t;
-  static constexpr TF_DataType kDType = TF_UINT8;
-};
-
-struct INT8Type {
-  using type = int8_t;
-  static constexpr TF_DataType kDType = TF_INT8;
-};
-
-struct INT64Type {
-  using type = int64_t;
-  static constexpr TF_DataType kDType = TF_INT64;
-};
-
-struct UINT16Type {
-  using type = uint16_t;
-  static constexpr TF_DataType kDType = TF_UINT16;
-};
-
-struct UINT32Type {
-  using type = uint32_t;
-  static constexpr TF_DataType kDType = TF_UINT32;
-};
-
-struct UINT64Type {
-  using type = uint64_t;
-  static constexpr TF_DataType kDType = TF_UINT64;
-};
-
-using SimpleTypes =
-    ::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
-                     INT64Type, UINT16Type, UINT32Type, UINT64Type>;
+using SimpleTypes = ::testing::Types<
+    tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
+    tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
+    tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
 
 template <typename T>
 class ConstructScalarTensorTest : public ::testing::Test {};
@@ -88,14 +41,13 @@
 // and verifies the expected dimensions, dtype, value, number of bytes, and
 // number of elements.
 TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
-  cc::Status status;
+  Status status;
   TF_DataType dtype = TypeParam::kDType;
   typename TypeParam::type value = 42;
-  cc::Tensor tensor =
-      cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
-                             /*data=*/&value,
-                             /*len=*/sizeof(value),
-                             /*deleter=*/[](void*, size_t) {}, &status);
+  Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
+                                     /*data=*/&value,
+                                     /*len=*/sizeof(value),
+                                     /*deleter=*/[](void*, size_t) {}, &status);
   ASSERT_TRUE(status.ok()) << status.message();
 
   EXPECT_EQ(tensor.dims(), 0);
@@ -113,7 +65,7 @@
 // and verifies the expected dimensions, dtype, value, number of bytes, and
 // number of elements.
 TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
-  cc::Status status;
+  Status status;
   TF_DataType dtype = TypeParam::kDType;
   // This is our 1D tensor of varying dtype.
   std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
@@ -121,7 +73,7 @@
   std::vector<int64_t> shape;
   shape.push_back(value.size());
 
-  cc::Tensor tensor = cc::Tensor::FromBuffer(
+  Tensor tensor = Tensor::FromBuffer(
       /*dtype=*/dtype, /*shape=*/shape,
       /*data=*/value.data(),
       /*len=*/value.size() * sizeof(typename TypeParam::type),
@@ -130,7 +82,7 @@
 
   EXPECT_EQ(tensor.dims(), 1);
   EXPECT_EQ(tensor.dtype(), dtype);
-  gtl::ArraySlice<typename TypeParam::type> tensor_view(
+  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
       reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
   EXPECT_EQ(tensor_view[0], 42);
   EXPECT_EQ(tensor_view[1], 100);
@@ -152,14 +104,14 @@
 // and verifies the expected dimensions, dtype, value, number of bytes, and
 // number of elements.
 TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
-  cc::Status status;
+  Status status;
   TF_DataType dtype = TypeParam::kDType;
   // This is our 1D tensor of varying dtype.
   std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
   // Shape is Rank 2 vector with shape 2 x 3.
   std::vector<int64_t> shape({2, 3});
 
-  cc::Tensor tensor = cc::Tensor::FromBuffer(
+  Tensor tensor = Tensor::FromBuffer(
       /*dtype=*/dtype, /*shape=*/shape,
       /*data=*/value.data(),
       /*len=*/value.size() * sizeof(typename TypeParam::type),
@@ -169,7 +121,7 @@
 
   EXPECT_EQ(tensor.dims(), 2);
   EXPECT_EQ(tensor.dtype(), dtype);
-  gtl::ArraySlice<typename TypeParam::type> tensor_view(
+  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
       reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
   EXPECT_EQ(tensor_view[0], 42);
   EXPECT_EQ(tensor_view[1], 100);
@@ -185,22 +137,22 @@
 
 TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
   bool done = false;
-  cc::Status status;
+  Status status;
   std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
   {
     // data_vector is a rank 1 tensor.
     std::vector<int64_t> shape;
     shape.push_back(data_vector.size());
 
-    cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
+    Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
       done = true;
     };
 
-    cc::Tensor tensor =
-        cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
-                               /*data=*/data_vector.data(),
-                               /*len=*/data_vector.size() * sizeof(int32_t),
-                               /*deleter=*/callback, &status);
+    Tensor tensor =
+        Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
+                           /*data=*/data_vector.data(),
+                           /*len=*/data_vector.size() * sizeof(int32_t),
+                           /*deleter=*/callback, &status);
     ASSERT_TRUE(status.ok()) << status.message();
   }
   // At this point, tensor has been destroyed, and the deleter callback should
@@ -209,4 +161,3 @@
 }
 
 }  // namespace
-}  // namespace tensorflow
diff --git a/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h
new file mode 100644
index 0000000..af9cad7
--- /dev/null
+++ b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h
@@ -0,0 +1,76 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_
+#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_
+
+#include <stdint.h>
+
+#include "tensorflow/c/tf_datatype.h"
+
+namespace tensorflow {
+
+// Each of the following struct types has two members: a kDType that
+// corresponds to a TF_DataType enum value, and a typedef "type"
+// of its corresponding C++ type. These types allow us to write Dtype-agnostic
+// tests via GoogleTest's TypedTests:
+// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
+struct FloatType {
+  using type = float;
+  static constexpr TF_DataType kDType = TF_FLOAT;
+};
+
+struct DoubleType {
+  using type = double;
+  static constexpr TF_DataType kDType = TF_DOUBLE;
+};
+
+struct Int32Type {
+  using type = int32_t;
+  static constexpr TF_DataType kDType = TF_INT32;
+};
+
+struct UINT8Type {
+  using type = uint8_t;
+  static constexpr TF_DataType kDType = TF_UINT8;
+};
+
+struct INT8Type {
+  using type = int8_t;
+  static constexpr TF_DataType kDType = TF_INT8;
+};
+
+struct INT64Type {
+  using type = int64_t;
+  static constexpr TF_DataType kDType = TF_INT64;
+};
+
+struct UINT16Type {
+  using type = uint16_t;
+  static constexpr TF_DataType kDType = TF_UINT16;
+};
+
+struct UINT32Type {
+  using type = uint32_t;
+  static constexpr TF_DataType kDType = TF_UINT32;
+};
+
+struct UINT64Type {
+  using type = uint64_t;
+  static constexpr TF_DataType kDType = TF_UINT64;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_
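These structs let a single TYPED_TEST body cover every dtype by pairing each TF_DataType enum value with its C++ type. A minimal sketch of the pattern (test and suite names are hypothetical):

    template <typename T>
    class DTypeParamTest : public ::testing::Test {};
    using TwoTypes =
        ::testing::Types<tensorflow::FloatType, tensorflow::Int32Type>;
    TYPED_TEST_SUITE(DTypeParamTest, TwoTypes);

    TYPED_TEST(DTypeParamTest, PairsDTypeWithCppType) {
      // TypeParam::type is the C++ type matching the TF_DataType enum kDType.
      using CppType = typename TypeParam::type;
      EXPECT_EQ(sizeof(CppType), TF_DataTypeSize(TypeParam::kDType));
    }
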
diff --git a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
new file mode 100644
index 0000000..cfeaba4
--- /dev/null
+++ b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <memory>
+
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/cc/experimental/base/public/runtime.h"
+#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
+#include "tensorflow/cc/experimental/base/public/tensor.h"
+#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+using tensorflow::experimental::cc::Runtime;
+using tensorflow::experimental::cc::RuntimeBuilder;
+using tensorflow::experimental::cc::Status;
+using tensorflow::experimental::cc::Tensor;
+using tensorflow::experimental::cc::TensorHandle;
+
+using SimpleTypes = ::testing::Types<
+    tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
+    tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
+    tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
+
+template <typename T>
+class ConstructScalarTensorHandleTest : public ::testing::Test {};
+TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
+
+// This test constructs a scalar tensor for each of the types in "SimpleTypes",
+// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
+// verify the expected dims, dtype, value, num bytes, and num elements.
+TYPED_TEST(ConstructScalarTensorHandleTest,
+           ValidTensorAttributesAfterConstruction) {
+  Status status;
+  RuntimeBuilder runtime_builder;
+  std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  TF_DataType dtype = TypeParam::kDType;
+  typename TypeParam::type value = 42;
+  Tensor original_tensor =
+      Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
+                         /*data=*/&value,
+                         /*len=*/sizeof(value),
+                         /*deleter=*/[](void*, size_t) {}, &status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  TensorHandle handle =
+      TensorHandle::FromTensor(original_tensor, *runtime, &status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  Tensor tensor = handle.Resolve(&status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  EXPECT_EQ(tensor.dims(), 0);
+  EXPECT_EQ(tensor.dtype(), dtype);
+  EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
+  EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
+  EXPECT_EQ(tensor.num_elements(), 1);
+}
+
+template <typename T>
+class Construct1DTensorHandleTest : public ::testing::Test {};
+TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
+
+// This test constructs a 1D tensor for each of the types in "SimpleTypes",
+// and verifies the expected dimensions, dtype, value, number of bytes, and
+// number of elements.
+TYPED_TEST(Construct1DTensorHandleTest,
+           ValidTensorAttributesAfterConstruction) {
+  Status status;
+  RuntimeBuilder runtime_builder;
+  std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  TF_DataType dtype = TypeParam::kDType;
+  // This is our 1D tensor of varying dtype.
+  std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
+  // Shape is Rank 1 vector.
+  std::vector<int64_t> shape;
+  shape.push_back(value.size());
+
+  Tensor original_tensor = Tensor::FromBuffer(
+      /*dtype=*/dtype, /*shape=*/shape,
+      /*data=*/value.data(),
+      /*len=*/value.size() * sizeof(typename TypeParam::type),
+      /*deleter=*/[](void*, size_t) {}, &status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  TensorHandle handle =
+      TensorHandle::FromTensor(original_tensor, *runtime, &status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  Tensor tensor = handle.Resolve(&status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  EXPECT_EQ(tensor.dims(), 1);
+  EXPECT_EQ(tensor.dtype(), dtype);
+  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
+      reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
+  EXPECT_EQ(tensor_view[0], 42);
+  EXPECT_EQ(tensor_view[1], 100);
+  EXPECT_EQ(tensor_view[2], 0);
+  EXPECT_EQ(tensor_view[3], 1);
+  EXPECT_EQ(tensor_view[4], 4);
+  EXPECT_EQ(tensor_view[5], 29);
+
+  EXPECT_EQ(tensor.num_bytes(),
+            value.size() * sizeof(typename TypeParam::type));
+  EXPECT_EQ(tensor.num_elements(), value.size());
+}
+
+template <typename T>
+class Construct2DTensorHandleTest : public ::testing::Test {};
+TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
+
+// This test constructs a 2D tensor for each of the types in "SimpleTypes",
+// and verifies the expected dimensions, dtype, value, number of bytes, and
+// number of elements.
+TYPED_TEST(Construct2DTensorHandleTest,
+           ValidTensorAttributesAfterConstruction) {
+  Status status;
+  RuntimeBuilder runtime_builder;
+  std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  TF_DataType dtype = TypeParam::kDType;
+  // This is our 1D tensor of varying dtype.
+  std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
+  // Shape is Rank 2 vector with shape 2 x 3.
+  std::vector<int64_t> shape({2, 3});
+
+  Tensor original_tensor = Tensor::FromBuffer(
+      /*dtype=*/dtype, /*shape=*/shape,
+      /*data=*/value.data(),
+      /*len=*/value.size() * sizeof(typename TypeParam::type),
+      /*deleter=*/[](void*, size_t) {}, &status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  TensorHandle handle =
+      TensorHandle::FromTensor(original_tensor, *runtime, &status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  Tensor tensor = handle.Resolve(&status);
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  EXPECT_EQ(tensor.dims(), 2);
+  EXPECT_EQ(tensor.dtype(), dtype);
+  tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
+      reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
+  EXPECT_EQ(tensor_view[0], 42);
+  EXPECT_EQ(tensor_view[1], 100);
+  EXPECT_EQ(tensor_view[2], 0);
+  EXPECT_EQ(tensor_view[3], 1);
+  EXPECT_EQ(tensor_view[4], 4);
+  EXPECT_EQ(tensor_view[5], 29);
+
+  EXPECT_EQ(tensor.num_bytes(),
+            value.size() * sizeof(typename TypeParam::type));
+  EXPECT_EQ(tensor.num_elements(), value.size());
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function.h b/tensorflow/cc/saved_model/experimental/public/concrete_function.h
index f57ba05..1adaf70 100644
--- a/tensorflow/cc/saved_model/experimental/public/concrete_function.h
+++ b/tensorflow/cc/saved_model/experimental/public/concrete_function.h
@@ -24,6 +24,7 @@
 #include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
@@ -54,6 +55,7 @@
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h
index bab9527..88cb779 100644
--- a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h
+++ b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h
@@ -22,6 +22,7 @@
 #include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // ConcreteFunctionList helps convert an opaque pointer to an array of
@@ -56,6 +57,7 @@
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
diff --git a/tensorflow/cc/saved_model/experimental/public/function_metadata.h b/tensorflow/cc/saved_model/experimental/public/function_metadata.h
index c3dcc45..11e1a86 100644
--- a/tensorflow/cc/saved_model/experimental/public/function_metadata.h
+++ b/tensorflow/cc/saved_model/experimental/public/function_metadata.h
@@ -21,6 +21,7 @@
 #include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // FunctionMetadata stores additional function information, including
@@ -40,6 +41,7 @@
 };
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h
index 814479d..04018bf 100644
--- a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h
+++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h
@@ -28,6 +28,7 @@
 #include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
 
 namespace tensorflow {
+namespace experimental {
 namespace cc {
 
 // SavedModelAPI offers a way to load Tensorflow Saved Models
@@ -155,6 +156,7 @@
 }
 
 }  // namespace cc
+}  // namespace experimental
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc
index 155c586..7f7f6b0 100644
--- a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc
+++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc
@@ -26,10 +26,14 @@
 #include "tensorflow/core/platform/stringpiece.h"
 #include "tensorflow/core/platform/test.h"
 
-namespace tensorflow {
 
 namespace {
 
+using tensorflow::experimental::cc::Runtime;
+using tensorflow::experimental::cc::RuntimeBuilder;
+using tensorflow::experimental::cc::SavedModelAPI;
+using tensorflow::experimental::cc::Status;
+
 constexpr char kTestData[] = "cc/saved_model/testdata";
 
 std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
@@ -43,21 +47,21 @@
 class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
 
 TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
-  cc::Status status;
-  cc::RuntimeBuilder builder;
+  Status status;
+  RuntimeBuilder builder;
   bool use_tfrt = GetParam();
   if (use_tfrt) {
     GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
   }
 
   builder.SetUseTFRT(use_tfrt);
-  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
+  std::unique_ptr<Runtime> runtime = builder.Build(&status);
   ASSERT_TRUE(status.ok()) << status.message();
 
   std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
   std::unordered_set<std::string> tags = {"serve"};
-  std::unique_ptr<cc::SavedModelAPI> model =
-      cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
+  std::unique_ptr<SavedModelAPI> model =
+      SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
 
   // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
   // That unblocks writing other tests that require a TF_SavedModel*,
@@ -67,20 +71,20 @@
 }
 
 TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
-  cc::Status status;
-  cc::RuntimeBuilder builder;
+  Status status;
+  RuntimeBuilder builder;
   bool use_tfrt = GetParam();
   if (use_tfrt) {
     GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
   }
 
   builder.SetUseTFRT(use_tfrt);
-  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
+  std::unique_ptr<Runtime> runtime = builder.Build(&status);
   ASSERT_TRUE(status.ok()) << status.message();
 
   std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
-  std::unique_ptr<cc::SavedModelAPI> model =
-      cc::SavedModelAPI::Load(model_dir, *runtime, &status);
+  std::unique_ptr<SavedModelAPI> model =
+      SavedModelAPI::Load(model_dir, *runtime, &status);
 
   // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
   // That unblocks writing other tests that require a TF_SavedModel*,
@@ -94,4 +98,3 @@
 
 }  // namespace
 
-}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index d907d28..f99b280 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -695,9 +695,9 @@
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LoopOpsTransforms",
         "@llvm-project//mlir:MlirTranslateMain",
         "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Translation",
diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py
index 6d3131a..f1271d0 100644
--- a/tensorflow/compiler/mlir/runlit.cfg.py
+++ b/tensorflow/compiler/mlir/runlit.cfg.py
@@ -70,9 +70,9 @@
 ]
 tool_names = [
     'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
-    'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
-    'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
-    'xla-opt'
+    'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
+    'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
+    'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py
index 661e620..3e7596c 100644
--- a/tensorflow/compiler/mlir/runlit.site.cfg.py
+++ b/tensorflow/compiler/mlir/runlit.site.cfg.py
@@ -44,6 +44,7 @@
     'tensorflow/compiler/mlir',
     'tensorflow/compiler/mlir/lite',
     'tensorflow/compiler/mlir/tensorflow',
+    'tensorflow/compiler/mlir/tfjs',
     'tensorflow/compiler/mlir/xla',
     'tensorflow/compiler/aot',
     'tensorflow/compiler/xla/service/mlir_gpu',
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 9099f2b..0edf0f3 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -556,7 +556,7 @@
     deps = [
         ":tensorflow",
         "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LoopOpsTransforms",
+        "@llvm-project//mlir:SCFTransforms",
     ],
     alwayslink = 1,
 )
@@ -823,6 +823,7 @@
         ":mangling_util",
         ":tensorflow_attributes",
         ":tensorflow_types",
+        "//tensorflow/compiler/xla:util",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index a33b339..9a29fa4 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -192,6 +192,44 @@
   let verifier = [{ return Verify(*this); }];
 }
 
+def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> {
+  let summary = "An Op to exchange data across TPU replicas.";
+
+  let description = [{
+On each replica, the input is split into `split_count` blocks along
+`split_dimension` and sent to the other replicas given `group_assignment`. After
+receiving `split_count` - 1 blocks from other replicas, we concatenate the
+blocks along `concat_dimension` as the output.
+
+For example, suppose there are 2 TPU replicas:
+replica 0 receives input: `[[A, B]]`
+replica 1 receives input: `[[C, D]]`
+
+group_assignment=`[[0, 1]]`
+concat_dimension=0
+split_dimension=1
+split_count=2
+
+replica 0's output: `[[A], [C]]`
+replica 1's output: `[[B], [D]]`
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
+    I32Tensor:$group_assignment,
+
+    I64Attr:$concat_dimension,
+    I64Attr:$split_dimension,
+    I64Attr:$split_count
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> {
   let summary = "Returns the argument of a complex number.";
 
@@ -1408,6 +1446,30 @@
   let hasCanonicalizer = 1;
 }
 
+def TF_ConjugateTransposeOp : TF_Op<"ConjugateTranspose", [NoSideEffect]> {
+  let summary = [{
+Shuffle dimensions of x according to a permutation and conjugate the result.
+  }];
+
+  let description = [{
+The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+  `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+  `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])`
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$x,
+    TF_I32OrI64Tensor:$perm
+  );
+
+  let results = (outs
+    TF_Tensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>;
+}
+
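Similarly, a small numeric sketch of the value relation above, assuming a 2x2 complex matrix and perm = [1, 0] (both illustrative):

    #include <complex>
    #include <iostream>
    #include <vector>

    using Cplx = std::complex<float>;

    int main() {
      // With perm = {1, 0}, the relation above reduces to
      // y[i][j] == conj(x[j][i]).
      std::vector<std::vector<Cplx>> x = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
      std::vector<std::vector<Cplx>> y(2, std::vector<Cplx>(2));
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j) y[i][j] = std::conj(x[j][i]);
      // Prints (1,-2) (5,-6) on the first row and (3,-4) (7,-8) on the second.
      for (const auto& row : y) {
        for (const auto& v : row) std::cout << v << " ";
        std::cout << "\n";
      }
      return 0;
    }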
 def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
   let summary = [{
 Computes a 2-D convolution given 4-D `input` and `filter` tensors.
@@ -3563,6 +3625,31 @@
   TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
 }
 
+def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
+  let summary = [{
+    Create a copy of `x` with the specified rows `i` updated with values `v`.
+  }];
+
+  let description = [{
+    Creates a copy of tensor `x` and updates the rows specified in tensor `i`
+    with the values `v`. Originally this function was mutative; however, for
+    compilation we make this operation create and operate on a copy.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$x,
+    I32Tensor:$i,
+    TF_Tensor:$v
+  );
+
+  let results = (outs
+    TF_Tensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
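And a corresponding sketch of the copy-then-update semantics (plain C++; the 3x2 shape and values are illustrative):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // x is 3x2; i selects the rows to overwrite; v holds replacement rows.
      // Per the description, the op returns an updated copy; x is untouched.
      std::vector<std::vector<float>> x = {{0, 0}, {0, 0}, {0, 0}};
      std::vector<int32_t> i = {0, 2};
      std::vector<std::vector<float>> v = {{1, 1}, {2, 2}};

      std::vector<std::vector<float>> y = x;  // copy, not in-place mutation
      for (size_t k = 0; k < i.size(); ++k) y[i[k]] = v[k];

      // Prints 1 1 / 0 0 / 2 2; x still holds all zeros.
      for (const auto& row : y) std::cout << row[0] << " " << row[1] << "\n";
      return 0;
    }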
 def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes the reciprocal of x element-wise.";
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index edba135..85baff5 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -986,18 +986,15 @@
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  for (NamedAttribute named_attr : attributes) {
-    if (named_attr.first.strref() != "value") continue;
-    auto value = named_attr.second;
-    if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
-      inferredReturnTypes.assign({elem_attr.getType()});
-      return success();
-    }
-    return emitOptionalError(location,
-                             "attribute 'value' failed to satisfy constraint: "
-                             "constant vector/tensor");
+  auto value = attributes.get("value");
+  if (!value) return emitOptionalError(location, "missing attribute 'value'");
+  if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
+    inferredReturnTypes.assign({elem_attr.getType()});
+    return success();
   }
-  return emitOptionalError(location, "missing attribute 'value'");
+  return emitOptionalError(location,
+                           "attribute 'value' failed to satisfy constraint: "
+                           "constant vector/tensor");
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
index 77ca08c..eb67bdc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
@@ -1,13 +1,17 @@
 // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure
 
-// Tests extraction of a single outside compiled cluster with no input or output dependecies.
+// Tests extraction of outside compiled ops at the head of a TPU computation.
 
-// CHECK-LABEL: func @nodep_single_head_outside_compilation
-func @nodep_single_head_outside_compilation() -> () {
-   // CHECK: "tf.A"
-   // CHECK-NEXT: "tf_device.launch"
-  "tf_device.launch"() ( {
-    "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
+func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
+  // CHECK:      tf_device.launch
+  // CHECK:        "tf.A"
+  // CHECK-NEXT:   tf_device.return
+  //
+  // CHECK:      "tf_device.cluster"
+  // CHECK:        "tf.C"
+  // CHECK-NEXT:   tf_device.return
+  "tf_device.cluster"() ( {
+    "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
     "tf.B"() : () -> ()
     "tf.C"() : () -> ()
     tf_device.return
@@ -15,15 +19,62 @@
   return
 }
 
-// CHECK-LABEL: func @nodep_multiple_head_outside_compilation
-func @nodep_multiple_head_outside_compilation() -> () {
-   // CHECK: "tf.A"
-   // CHECK-NEXT: "tf.B"
-   // CHECK-NEXT: "tf_device.launch"
-  "tf_device.launch"() ( {
-    "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
-    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
-    "tf.C"() : () -> ()
+// CHECK-LABEL: func @multiple_head_outside_compilation
+func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
+  // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+  // CHECK:        %[[A_OUT:.*]] = "tf.A"
+  // CHECK:        %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
+  // CHECK:        "tf.C"
+  // CHECK-NEXT:   tf_device.return %[[B_OUT]]
+  //
+  // CHECK:      "tf_device.cluster"
+  // CHECK:        "tf.D"(%[[LAUNCH_OUT]])
+  // CHECK-NEXT:   tf_device.return
+  "tf_device.cluster"() ( {
+    %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+    %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+    "tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
+    "tf.D"(%1) : (tensor<i32>) -> ()
+    tf_device.return
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
+func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
+  // CHECK-NOT:  tf_device.launch
+  // CHECK:      "tf_device.cluster"
+  // CHECK-NEXT:   "tf.A"
+  // CHECK-NEXT:   "tf.B"
+  // CHECK-NEXT:   "tf.C"
+  // CHECK-NEXT:   tf_device.return
+  "tf_device.cluster"() ( {
+    %0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
+    %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
+    "tf.C"(%1) : (tensor<i32>) -> ()
+    tf_device.return
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
+func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
+  // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+  // CHECK:        %[[A_OUT:.*]] = "tf.A"
+  // CHECK:        %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
+  // CHECK-NEXT:   tf_device.return %[[D_OUT]]
+  //
+  // CHECK:      "tf_device.cluster"
+  // CHECK:        "tf.B"
+  // CHECK:        "tf.C"
+  // CHECK:        "tf.E"
+  // CHECK-NEXT:   tf_device.return
+  "tf_device.cluster"() ( {
+    %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+    %1 = "tf.B"() {} : () -> (tensor<i32>)
+    %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
+    %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
+    %4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
     tf_device.return
   }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
   return
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index af0119d..b8a48bb 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -1222,6 +1222,41 @@
 
 // -----
 
+// Tests a simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
+  // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
+  func @replicated_parallel_tpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      // CHECK: "tf._TPUCompileMlir"
+      // CHECK: "tf.TPUCompileSucceededAssert"
+      // CHECK: "tf_device.parallel_execute"
+      // CHECK:    "tf.TPUExecute"
+      %3 = "tf_device.parallel_execute"() ( {
+        "tf.D"() : () -> ()
+        tf_device.return
+      }, {
+        %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
+
+        tf_device.return %4 : tensor<?xi32>
+      }) : () -> (tensor<?xi32>)
+      tf_device.return %3 : tensor<?xi32>
+    }
+    %2 = "tf.C"(%1#1) : (tensor<?xi32>) -> tensor<?xi32>
+    return %2 : tensor<?xi32>
+  }
+
+  func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+}
+
+// -----
+
 // Tests devices are set properly for non replicated model parallelism.
 
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index c1d99c2..0b1ff2b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -258,7 +258,7 @@
 
 // Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
 // at head/tail of TPU cluster to run before/after TPU computation.
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTPUExtractHeadTailOutsideCompilationPass();
 
 // Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
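Since the factory now returns an OperationPass<ModuleOp> rather than a function pass, callers add it directly to a top-level pass manager. A hedged sketch; the wrapper function around the call is illustrative, not taken from this patch:

    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

    // Hypothetical pipeline hook: the pass now runs over the whole module, so
    // it is added with addPass rather than addNestedPass<FuncOp>.
    void AddHeadTailExtraction(mlir::PassManager& pm) {
      pm.addPass(mlir::TFTPU::CreateTPUExtractHeadTailOutsideCompilationPass());
    }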
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index efe82c4..5a2cae3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -66,8 +66,7 @@
 namespace mlir {
 namespace TF {
 namespace {
-Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
-    FuncOp func) {
+Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
   // Find any return ops.
   SmallVector<ReturnOp, 4> return_ops;
   for (Block& block : func) {
@@ -137,9 +136,9 @@
       cast_op = b.create<TF::CastOp>(op->getLoc(), old_type, result,
                                      /*truncate=*/b.getBoolAttr(false));
     }
-    return mlir::Value(cast_op);
+    return Value(cast_op);
   };
-  for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
+  for (OpOperand& use : make_early_inc_range(result.getUses())) {
     if (use.getOwner()->getDialect() != tf_dialect &&
         !IsSupportedNonTFOp(use.getOwner()))
       use.set(get_cast_op());
@@ -162,7 +161,7 @@
 bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
                                  Operation* op, Dialect* tf_dialect) {
   bool changed = false;
-  for (auto entry : llvm::zip(pass_through_operands, op->getResults())) {
+  for (auto entry : zip(pass_through_operands, op->getResults())) {
     Type operand_type = std::get<0>(entry).getType();
     Value result = std::get<1>(entry);
     if (result.getType() == operand_type) continue;
@@ -204,7 +203,7 @@
         tf_dialect);
   }
   // TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp.
-  if (auto tensor_cast = dyn_cast<mlir::TensorCastOp>(op)) {
+  if (auto tensor_cast = dyn_cast<TensorCastOp>(op)) {
     return InferShapeForPassThroughOps(
         tensor_cast.getOperation()->getOperands(), op, tf_dialect);
   }
@@ -254,7 +253,7 @@
 // match the i-th operand type). Returns true if anything is changed.
 bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
   bool changed = false;
-  for (auto entry : llvm::zip(operands, results)) {
+  for (auto entry : zip(operands, results)) {
     Type operand_type = std::get<0>(entry).getType();
     Type result_type = std::get<1>(entry).getType();
     if (operand_type == result_type) continue;
@@ -291,14 +290,13 @@
   CallInterfaceCallable callable = call_op.getCallableForCallee();
   SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>();
   if (!sym) return false;
-  FuncOp func =
-      dyn_cast<mlir::FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
+  FuncOp func = dyn_cast<FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
   if (!func) return false;
 
   bool changed = false;
   // Map each of the results of the call to the returned type of the
   // function.
-  for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) {
+  for (auto result : zip(op->getResults(), func.getType().getResults())) {
     if (std::get<0>(result).getType() == std::get<1>(result)) continue;
     // Skip already statically shaped results.
     if (!CanBeRefined(std::get<0>(result).getType())) continue;
@@ -335,7 +333,7 @@
   // Map each of the results of the call to the returned type of the
   // function.
   bool changed = false;
-  for (auto result : llvm::zip(op->getResults(), inferred)) {
+  for (auto result : zip(op->getResults(), inferred)) {
     if (std::get<0>(result).getType() == std::get<1>(result)) continue;
 
     // Inserts a cast back to the original type if any user is not in the
@@ -356,7 +354,7 @@
 // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
 // scalar value).
 struct ValuePort {
-  llvm::PointerUnion<Operation*, BlockArgument> producer;
+  PointerUnion<Operation*, BlockArgument> producer;
   SmallVector<unsigned int, 2> port;
 
   bool operator==(const ValuePort& other) const {
@@ -374,39 +372,38 @@
       port = {0};
     }
   }
-  ValuePort(llvm::PointerUnion<Operation*, BlockArgument> producer,
+  ValuePort(PointerUnion<Operation*, BlockArgument> producer,
             SmallVector<unsigned int, 2> port)
       : producer(producer), port(port) {}
 
-  llvm::raw_ostream& print(llvm::raw_ostream& os) const {
+  raw_ostream& print(raw_ostream& os) const {
     if (auto* op = producer.dyn_cast<Operation*>())
       os << "op " << op->getName();
     if (auto ba = producer.dyn_cast<BlockArgument>())
       os << "block_arg " << ba.getArgNumber();
-    os << llvm::formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
+    os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
     return os;
   }
 };
 
 struct ValuePortHasher {
   std::size_t operator()(const ValuePort& other) const {
-    return llvm::hash_combine(
-        llvm::hash_value(other.producer.getOpaqueValue()),
-        llvm::hash_value(ArrayRef<unsigned int>(other.port)));
+    return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
+                        hash_value(ArrayRef<unsigned int>(other.port)));
   }
 };
 
 using ValuePortResultMap =
     std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
-using ComputedQueryFn = llvm::function_ref<bool(ValuePort)>;
-using ValueQueryFn = llvm::function_ref<Attribute(const ValuePort&)>;
-using ValuePortInputs = llvm::SmallVectorImpl<ValuePort>;
+using ComputedQueryFn = function_ref<bool(ValuePort)>;
+using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
+using ValuePortInputs = SmallVectorImpl<ValuePort>;
 
-// TODO(jpienaar): InputsRequiredForOutput and ComputeOutputComponent are
+// TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
 // intended to be switched to op interfaces once more refined.
-LogicalResult InputsRequiredForOutput(ValuePort value_port,
-                                      ComputedQueryFn has_been_computed,
-                                      ValuePortInputs* inputs) {
+LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
+                                             ComputedQueryFn has_been_computed,
+                                             ValuePortInputs* inputs) {
   auto op = value_port.producer.dyn_cast<Operation*>();
   auto& port = value_port.port;
   if (!op) return failure();
@@ -460,26 +457,94 @@
   return nullptr;
 }
 
-ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
+// Context used during ShapeInference. This class contains common information
+// that is required by the individual shape inference helper functions (e.g.,
+// TF Graph version, constant values computed, etc.)
+class ShapeInference {
+ public:
+  ShapeInference(int64_t graph_version, MLIRContext* context);
+
+  LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
+                                               ValuePortInputs* inputs) {
+    return ::mlir::TF::ComputeInputsRequiredForOutput(
+        value_port,
+        [this](const ValuePort& port) {
+          return results_.find(port) != results_.end();
+        },
+        inputs);
+  }
+
+  Attribute ComputeOutputComponent(const ValuePort& value_port) {
+    return ::mlir::TF::ComputeOutputComponent(
+        value_port, [this](const ValuePort& port) { return results_[port]; });
+  }
+
+  // Returns a ShapeHandle if the op result could be computed as a shape.
+  ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
+
+  void RecordValue(const ValuePort& value_port, Attribute value) {
+    results_[value_port] = value;
+  }
+
+  // Performs shape inference on the provided op and returns true if the type
+  // of at least one result has been changed.
+  // A tf.Cast() is inserted for any use that isn't in the TensorFlow dialect.
+  // `graph_version_` indicates the current GraphDef compatibility version
+  // (the versions field in graph.proto).
+  bool InferShapeForSingleOperation(Operation* op);
+
+  // Infers shapes in the provided region, including nested regions, iterating
+  // until a fixed point is reached, with a limit of max_iteration. Returns
+  // success if the fixed point is reached before max_iteration.
+  LogicalResult InferShapeUntilFixPoint(Region* region,
+                                        int64_t max_iteration = 10);
+
+  // Updates input types and refines shapes inside the bodies of functions that
+  // are attached to ControlFlow ops (If/While). These functions include the
+  // Then/Else branches of IfOp and the Cond/Body functions of WhileOp. These
+  // functions share the following common properties:
+  //   1) They are never reused, i.e., each has a single use in the module.
+  //   2) Their input types match those of their parent ops (excluding inputs
+  //      like the predicate).
+  // Returns success if the shapes could be refined; failure otherwise.
+  LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
+                                              ArrayRef<Type> input_types,
+                                              int64_t max_iteration);
+
+  // Propagates the shapes to the named functions.
+  LogicalResult PropagateShapeToFunctions(
+      ModuleOp module, Operation::operand_type_range input_types,
+      ArrayRef<StringRef> func_names, int64_t max_iteration);
+
+  // Shape propagation for call/control flow ops.
+  LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
+                                                    int64_t max_iteration);
+
+ private:
+  // Mapping from a ValuePort (which corresponds to an OpResult or smaller,
+  // e.g., the first element of an OpResult produced) to an Attribute if the
+  // ValuePort corresponds to a constant value.
+  ValuePortResultMap results_;
+  int64_t graph_version_;
+  MLIRContext* context_;
+  Dialect* tf_dialect_;
+};
+
+ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
+    : graph_version_(graph_version) {
+  context_ = context;
+  tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
+}
+
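For orientation, the intended call pattern for the new class, mirroring how InferShapeForFunction uses it later in this patch. The wrapper itself is a sketch, written as if inside this file's anonymous namespace:

    // Sketch: drive shape inference over a function body through the new
    // ShapeInference context instead of the removed free functions.
    LogicalResult RunInference(FuncOp func, int64_t graph_version) {
      ShapeInference context(graph_version, func.getContext());
      return context.InferShapeUntilFixPoint(&func.getBody());
    }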
+ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
+                                                 InferenceContext* ic) {
   LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
   auto rt = result.getType().dyn_cast<RankedTensorType>();
   if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
   int dim_size = rt.getDimSize(0);
 
   // Worklist to direct partial evaluation.
-  llvm::SmallVector<ValuePort, 4> worklist;
-  // The ValuePort evaluated results.
-  // TODO(jpienaar): This could be cached across invocations (e.g., part of some
-  // inference context).
-  ValuePortResultMap evaluated;
-  // Returns whether a ValuePort has been previously computed.
-  auto has_been_computed = [&evaluated](const ValuePort& port) {
-    return evaluated.find(port) != evaluated.end();
-  };
-  // Returns previously computed ValuePort value.
-  auto values = [&evaluated](const ValuePort& port) -> Attribute {
-    return evaluated[port];
-  };
+  SmallVector<ValuePort, 4> worklist;
 
   // Simple evaluator that attempts to partially evaluate the input value even
   // if unable to evaluate the complete output. Below follows a simple stack
@@ -498,7 +563,7 @@
       LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front "));
 
       SmallVector<ValuePort, 4> inputs;
-      auto res = InputsRequiredForOutput(front, has_been_computed, &inputs);
+      auto res = ComputeInputsRequiredForOutput(front, &inputs);
       if (failed(res)) {
         // Abort if unable to find which required inputs need to be computed.
         worklist.clear();
@@ -513,16 +578,16 @@
         continue;
       }
 
-      auto ret = ComputeOutputComponent(front, values);
+      auto ret = ComputeOutputComponent(front);
       if (!ret) continue;
 
-      evaluated[front] = ret;
+      RecordValue(front, ret);
       LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
 
       // If worklist is empty, then this is the root query op.
       if (worklist.empty()) {
         LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
-        if (auto dea = ret.dyn_cast<mlir::DenseIntElementsAttr>()) {
+        if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
           if (dea.getNumElements() != 1) {
             LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n");
             return {};
@@ -536,9 +601,8 @@
   return ic->MakeShape(dims);
 }
 
-bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
-                                  int64_t graph_version) {
-  assert(tf_dialect == op->getDialect());
+bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
+  assert(tf_dialect_ == op->getDialect());
   // The shape function of these ops sometimes does not propagate subtypes
   // (handle shapes) for resource and variant types. We use a simple passthrough
   // to make sure they are preserved in the output.
@@ -550,7 +614,7 @@
   // If no result for this op needs shape inference, we have a fast-path return.
   // But if the type is a resource/variant, we do not skip it because we might
   // not have the handle shapes.
-  if (llvm::none_of(op->getResultTypes(), CanBeRefined)) {
+  if (none_of(op->getResultTypes(), CanBeRefined)) {
     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
                             << op->getName() << "'.\n");
     return false;
@@ -565,8 +629,8 @@
   // This is necessary to avoid reprocessing the tf.Cast that are inserted at
   // the end of this function.
   if (isa<CastOp>(op) &&
-      llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) {
-        return user->getDialect() != tf_dialect;
+      all_of(op->getResult(0).getUsers(), [&](Operation* user) {
+        return user->getDialect() != tf_dialect_;
       })) {
     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
                                "dialect operation users '"
@@ -646,7 +710,7 @@
   // Perform the shape inference using an InferenceContext with the input
   // shapes. This object is abstracting the information that the ShapeInference
   // function operates on.
-  InferenceContext c(graph_version, *node_def, op_reg_data->op_def,
+  InferenceContext c(graph_version_, *node_def, op_reg_data->op_def,
                      input_shapes, input_tensors,
                      /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
   auto status = c.Run(op_reg_data->shape_inference_fn);
@@ -659,7 +723,7 @@
   // Determine if, during shape computation, the shape functions attempted to
   // query an input operand as shape where the input was not known/constant.
   bool requires_inputs =
-      llvm::any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
+      any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
         return c.requested_input_tensor_as_partial_shape(input) &&
                !input_tensors[input];
       });
@@ -723,7 +787,7 @@
         new_element_type.isa<TF::VariantType>()) {
       auto handle_shapes_types = c.output_handle_shapes_and_types(output);
       if (handle_shapes_types) {
-        llvm::SmallVector<mlir::TensorType, 1> subtypes;
+        SmallVector<TensorType, 1> subtypes;
         OpBuilder b(op);
         for (const auto& shape_n_type : *handle_shapes_types) {
           Type element_type;
@@ -743,7 +807,7 @@
     if (result.getType() == new_type) continue;
     // Inserts a cast back to the original type if any user is not in the TF
     // dialect.
-    AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
+    AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_,
                                        result.getType());
     // Finally we inferred the shape and replace the type for this result.
     result.setType(new_type);
@@ -755,23 +819,13 @@
   return changed;
 }
 
-// Updates input types and refine shapes inside body of functions that are
-// attached to ControlFlow ops (If/While). These functions include Then/Else
-// branches of IfOp and Cond/Body functions of WhileOp. These functions share
-// following common properties:
-//   1) They are never reused, ie. having a single use in module.
-//   2) Their input types match those of their parent ops (excluding inputs like
-//      predicate).
-// Returns a boolean indicating whether any change has been applied.
-LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
-                                            llvm::ArrayRef<Type> input_types,
-                                            int64_t graph_version,
-                                            int64_t max_iteration) {
+LogicalResult ShapeInference::RefineShapeForControlFlowFunc(
+    FuncOp func, ArrayRef<Type> input_types, int64_t max_iteration) {
   ModuleOp module = func.getParentOfType<ModuleOp>();
   auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
   int num_uses = std::distance(func_uses->begin(), func_uses->end());
   if (num_uses != 1) {
-    func.emitWarning(llvm::formatv(
+    func.emitWarning(formatv(
         "expected control flow function {0} to have exactly 1 use, found {1}.",
         func.getName(), num_uses));
     return failure();
@@ -785,8 +839,7 @@
     arg_and_idx.value().setType(input_types[arg_and_idx.index()]);
   }
 
-  auto res =
-      InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration);
+  auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration);
   if (failed(res)) return res;
 
   auto new_return_types = InferShapeForFunctionReturnType(func);
@@ -798,20 +851,18 @@
   return success();
 }
 
-LogicalResult PropagateShapeToFunctions(
+LogicalResult ShapeInference::PropagateShapeToFunctions(
     ModuleOp module, Operation::operand_type_range input_types,
-    llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
-    int64_t max_iteration) {
-  bool success = true;
+    ArrayRef<StringRef> func_names, int64_t max_iteration) {
+  bool all_succeeded = true;
   auto types = llvm::to_vector<4>(input_types);
   for (auto func_name : func_names) {
     FuncOp func = module.lookupSymbol<FuncOp>(func_name);
-    if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
-                                             max_iteration))) {
-      success = false;
-    }
+    all_succeeded =
+        succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) &&
+        all_succeeded;
   }
-  return mlir::success(success);
+  return success(all_succeeded);
 }
 
 // If the callee has only one use, propagates any constant operand of call_op to
@@ -831,7 +882,7 @@
     // the constant inside the function.
     for (auto arg : func.getArguments()) {
       auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
-      if (llvm::isa_and_nonnull<TF::ConstOp>(operand)) {
+      if (isa_and_nonnull<TF::ConstOp>(operand)) {
         arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
       }
     }
@@ -850,33 +901,31 @@
   for (auto retval :
        llvm::enumerate(func.front().getTerminator()->getOperands())) {
     auto retval_op = retval.value().getDefiningOp();
-    if (llvm::isa_and_nonnull<TF::ConstOp>(retval_op)) {
+    if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
       op->getResult(retval.index())
           .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
     }
   }
 }
 
-LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
-                                                  int64_t graph_version,
-                                                  int64_t max_iteration) {
+LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
+    Operation* op, int64_t max_iteration) {
   ModuleOp module = op->getParentOfType<ModuleOp>();
   if (auto if_op = dyn_cast<TF::IfOp>(op)) {
     return PropagateShapeToFunctions(
-        module, llvm::drop_begin(if_op.getOperandTypes(), 1),
-        {if_op.then_branch(), if_op.else_branch()}, graph_version,
-        max_iteration);
+        module, drop_begin(if_op.getOperandTypes(), 1),
+        {if_op.then_branch(), if_op.else_branch()}, max_iteration);
   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
     return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
                                      {while_op.cond(), while_op.body()},
-                                     graph_version, max_iteration);
+                                     max_iteration);
   } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
     CallInterfaceCallable callable = call_op.getCallableForCallee();
     if (SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>()) {
       PropagateConstantToCallee(call_op, sym, module);
       if (failed(PropagateShapeToFunctions(
               module, call_op.getArgOperands().getTypes(),
-              {sym.getRootReference()}, graph_version, max_iteration))) {
+              {sym.getRootReference()}, max_iteration))) {
         return failure();
       }
       PropagateConstantFromCallee(call_op, sym, module);
@@ -889,13 +938,10 @@
   return success();
 }
 
-LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
-                                      int64_t max_iteration) {
-  MLIRContext* ctx = region->getContext();
-  Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>();
-
-  // An operation folder that is used to attempt folding before inference.
-  OperationFolder folder(ctx);
+LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
+                                                      int64_t max_iteration) {
+  // An operation folder that is used to attempt folding before inference.
+  OperationFolder folder(context_);
   bool changed = true;
 
   // TODO(aminim): we could have a more efficient traversal by guiding the
@@ -908,14 +954,14 @@
                << "Shape inference, iteration " << iteration << "\n");
     region->walk([&](Operation* op) {
       if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
-        changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect);
+        changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
         // TODO(jpienaar): Debug why we can't just return here. We end up with
         // additional constant due to the propagation of constant into attached
         // function if we return already.
       }
 
-      if (op->getDialect() != tf_dialect) {
-        changed |= InferShapeForNonTFDialectOperation(op, tf_dialect);
+      if (op->getDialect() != tf_dialect_) {
+        changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_);
         return;
       }
 
@@ -924,13 +970,12 @@
 
       // Best-effort shape inference in attached functions. Do not return
       // failure even if it doesn't get to fixed point.
-      if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version,
-                                                     max_iteration))) {
+      if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
         op->emitWarning() << "unable to refine shape of attached function "
                              "arguments and bodies";
       }
 
-      changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version);
+      changed |= InferShapeForSingleOperation(op);
     });
   }
 
@@ -945,31 +990,43 @@
 LogicalResult InferShapeForFunction(FuncOp func,
                                     ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                     int64_t graph_version) {
-  mlir::FunctionType func_type = func.getType();
+  ShapeInference context(graph_version, func.getContext());
+  if (arg_shapes.empty()) {
+    if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
+      return failure();
+    // TODO(b/156276510): Verify that it is always fine to refine a function's
+    // return type, as long as we do not change the argument shapes.
+    if (auto return_types = InferShapeForFunctionReturnType(func)) {
+      func.setType(FunctionType::get(func.getType().getInputs(),
+                                     return_types.getValue(),
+                                     func.getContext()));
+    }
+
+    return success();
+  }
+  FunctionType func_type = func.getType();
   bool needs_refinement = false;
-  llvm::SmallVector<mlir::Type, 4> new_arg_types;
+  SmallVector<Type, 4> new_arg_types;
   new_arg_types.reserve(func_type.getNumInputs());
 
   // Update argument types in-place using the provided arg_shapes.
   for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
     ArrayRef<int64_t> shape = arg_shapes[i];
-    mlir::Type element_type;
-    if (auto input_ty =
-            func_type.getInput(i).dyn_cast<mlir::RankedTensorType>()) {
+    Type element_type;
+    if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
       if (!input_ty || input_ty.getShape().size() != shape.size()) {
         return failure();
       }
       element_type = input_ty.getElementType();
     } else {
-      auto unranked_input_ty =
-          func_type.getInput(i).dyn_cast<mlir::TensorType>();
+      auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
       if (!unranked_input_ty) {
         return failure();
       }
       element_type = unranked_input_ty.getElementType();
     }
 
-    auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
+    auto new_arg_type = RankedTensorType::get(shape, element_type);
     if (new_arg_type != func_type.getInput(i)) {
       // If the new type is more detailed, trigger shape inference.
       func.getArgument(i).setType(new_arg_type);
@@ -982,28 +1039,17 @@
     return success();
   }
 
-  mlir::LogicalResult result =
-      mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version);
+  LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody());
   if (failed(result)) {
     return failure();
   }
 
   auto return_types = InferShapeForFunctionReturnType(func);
-  func.setType(mlir::FunctionType::get(new_arg_types,
-                                       return_types.hasValue()
-                                           ? return_types.getValue()
-                                           : func.getType().getResults(),
-                                       func.getContext()));
-
-  return success();
-}
-
-LogicalResult InferShapeForFunctionType(FuncOp func) {
-  if (auto return_types = InferShapeForFunctionReturnType(func)) {
-    func.setType(mlir::FunctionType::get(func.getType().getInputs(),
-                                         return_types.getValue(),
-                                         func.getContext()));
-  }
+  func.setType(FunctionType::get(new_arg_types,
+                                 return_types.hasValue()
+                                     ? return_types.getValue()
+                                     : func.getType().getResults(),
+                                 func.getContext()));
 
   return success();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
index 0524ec6..e36d8d5 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
@@ -27,30 +27,13 @@
 
 namespace TF {
 
-// Performs shape inference on the provided op and return true if the type of
-// at least one result has been changed.
-// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
-// `graph_version` indicates the current GraphDef compatibility versions
-// (the versions field in graph.proto).
-bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
-                                  int64_t graph_version);
-
-// Infers shape on the provided region, including nested ones, iterate until fix
-// point with a limit of max_iteration. Returns success if fix point is reached
-// before max_iteration.
-LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
-                                      int64_t max_iteration = 10);
-
 // Given a list of refined shapes matching the function arguments of func, runs
 // shape inference over the function to propagate this updated information.
+// If arg_shapes are empty, then argument shapes will be left unchanged.
 LogicalResult InferShapeForFunction(FuncOp func,
                                     ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                     int64_t graph_version);
 
-// Refines the return type of the given function by folding tf.Cast that
-// precedes the return instruction.
-LogicalResult InferShapeForFunctionType(FuncOp func);
-
 }  // namespace TF
 
 }  // namespace mlir
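With InferShapeForFunctionType folded in, callers pass empty arg_shapes to refine only the return type (as the shape inference pass below now does), or explicit shapes to refine the arguments as well. A hedged sketch of the latter; the function name, 2x3 shape, and producer version are illustrative:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"
    #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"

    // Sketch: refine a function whose single argument is known to be 2x3.
    mlir::LogicalResult RefineMain(mlir::FuncOp func, int64_t producer) {
      llvm::SmallVector<int64_t, 2> shape{2, 3};
      llvm::ArrayRef<int64_t> arg_shapes[] = {shape};
      return mlir::TF::InferShapeForFunction(func, arg_shapes, producer);
    }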
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
index 48e4e77..acdfc0e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
@@ -58,10 +58,8 @@
     }
     int64_t producer = producer_or.ValueOrDie();
     for (auto func : module.getOps<FuncOp>()) {
-      InferShapeUntilFixPoint(&func.getBody(), producer);
-      // TODO(yuanzx): Verify that it is always fine to refine a function's
-      // return type, as long as we do not change the argument shapes.
-      InferShapeForFunctionType(func);
+      if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer)))
+        return signalPassFailure();
     }
   }
 };
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
index 141feeb..b9e2144 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
@@ -14,11 +14,23 @@
 ==============================================================================*/
 
 #include <memory>
+#include <type_traits>
 
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
 
 namespace mlir {
 namespace TFTPU {
@@ -30,30 +42,182 @@
 
 constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
 
-struct TPUExtractHeadTailOutsideCompilation
-    : public PassWrapper<TPUExtractHeadTailOutsideCompilation, FunctionPass> {
-  void runOnFunction() override;
-};
+bool HasOutsideCompilationAttribute(Operation* op) {
+  return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
+}
 
-void TPUExtractHeadTailOutsideCompilation::runOnFunction() {
-  getFunction().walk([&](tf_device::LaunchOp launch) {
-    Block& launch_block = launch.GetBody();
-    for (auto& op : llvm::make_early_inc_range(launch_block.getOperations())) {
-      // TODO(b/155115766): Handle outputs that should be inputs to TPU
-      // LaunchOp.
-      if (auto attr =
-              op.getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
-        op.moveBefore(launch);
-      } else {
+// Returns whether all operands of `op` are from values inside the
+// `input_value_set`.
+bool OpContainsOperandsFromSet(Operation* op,
+                               const llvm::SetVector<Value>& input_value_set) {
+  for (auto operand : op->getOperands())
+    if (input_value_set.count(operand) == 0) return false;
+
+  return true;
+}
+
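+// Records `op` in `outside_compiled_ops` and adds its results to
+// `outside_compiled_op_usages` if `op` is outside compiled and all of its
+// operands already come from `outside_compiled_op_usages`.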
+void RecordOutsideCompiledOpsAndUsages(
+    Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
+    llvm::SetVector<Value>* outside_compiled_op_usages) {
+  if (HasOutsideCompilationAttribute(op) &&
+      OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
+    outside_compiled_ops->insert(op);
+    outside_compiled_op_usages->insert(op->getResults().begin(),
+                                       op->getResults().end());
+  }
+}
+
+// Traverses the MLIR graph and collects the set of outside compiled ops
+// that are connected to inputs of the TPU computation.
+void ExtractOutsideCompiledOpsConnectedToHead(
+    Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
+    llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
+  llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
+  for (auto& usage : input_value.getUses()) {
+    auto head_operation = usage.getOwner();
+    RecordOutsideCompiledOpsAndUsages(head_operation,
+                                      &parent_outside_compiled_ops_at_head,
+                                      values_used_in_host_cluster);
+  }
+
+  // Traverse the graph and find all outside compiled ops connected from
+  // the `input_value`.
+  while (!parent_outside_compiled_ops_at_head.empty()) {
+    llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
+    for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
+      auto op_results = head_outside_compiled_op->getOpResults();
+      for (auto op_result : op_results) {
+        for (auto& use : op_result.getUses()) {
+          auto connected_op = use.getOwner();
+          RecordOutsideCompiledOpsAndUsages(connected_op,
+                                            &connected_outside_compiled_ops,
+                                            values_used_in_host_cluster);
+        }
+      }
+    }
+
+    outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
+                                 parent_outside_compiled_ops_at_head.end());
+    std::swap(parent_outside_compiled_ops_at_head,
+              connected_outside_compiled_ops);
+  }
+}
+
+// TODO(hongjunchoi): Also handle ops without inputs that are outside
+// compiled.
+//
+// Returns the set of ops that are outside compiled and are directly connected
+// to the inputs of the TPU computation.
+llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
+    tf_device::ClusterOp tpu_cluster) {
+  llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
+  llvm::SetVector<Value> values_used_in_cluster;
+  auto& cluster_region = tpu_cluster.body();
+  getUsedValuesDefinedAbove(cluster_region, cluster_region,
+                            values_used_in_cluster);
+
+  auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
+  for (auto input_value : input_value_list)
+    ExtractOutsideCompiledOpsConnectedToHead(
+        input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
+  return outside_compiled_at_head_ops;
+}
+
+// Returns output values of extracted outside compiled cluster at head that
+// are used by the TPU computation.
+llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
+    const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
+  llvm::SmallVector<Value, 8> outputs;
+  outputs.reserve(head_outside_compiled_ops.size());
+
+  for (auto op : head_outside_compiled_ops) {
+    for (Operation* user : op->getUsers()) {
+      if (!head_outside_compiled_ops.count(user)) {
+        outputs.append(op->result_begin(), op->result_end());
         break;
       }
     }
+  }
+
+  return outputs;
+}
+
+// Creates a new tf_device.launch op containing the outside compiled ops
+// extracted from the head of the TPU computation.
+llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
+    OpBuilder* builder, tf_device::ClusterOp cluster,
+    const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
+  if (head_outside_compiled_ops.empty())
+    return llvm::Optional<tf_device::LaunchOp>();
+
+  // Create tf_device.launch op to separate all extracted outside compiled ops
+  // before the tf_device.cluster.
+  auto output_values =
+      GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
+
+  llvm::SmallVector<Type, 8> output_return_types;
+  output_return_types.reserve(output_values.size());
+  for (auto output : output_values)
+    output_return_types.emplace_back(output.getType());
+
+  builder->setInsertionPoint(cluster);
+  auto host_launch_op = builder->create<tf_device::LaunchOp>(
+      cluster.getLoc(), builder->getStringAttr(""), output_return_types);
+
+  // Replace all uses of the outside compiled ops inside the TPU computation
+  // with the results of the newly created launch op.
+  for (auto output_and_index : llvm::enumerate(output_values)) {
+    auto output_index = output_and_index.index();
+    auto output = output_and_index.value();
+    for (auto& use : output.getUses()) {
+      if (!head_outside_compiled_ops.count(use.getOwner()))
+        use.set(host_launch_op.getResult(output_index));
+    }
+  }
+
+  // Create terminator op for the newly created launch op.
+  host_launch_op.body().push_back(new Block());
+  builder->setInsertionPointToEnd(&host_launch_op.GetBody());
+  auto terminator = builder->create<tf_device::ReturnOp>(
+      host_launch_op.getLoc(), output_values);
+
+  // Move all outside compile ops from cluster op to launch op.
+  for (auto outside_compiled_op : head_outside_compiled_ops)
+    outside_compiled_op->moveBefore(terminator);
+
+  return host_launch_op;
+}
+
+struct TPUExtractHeadTailOutsideCompilation
+    : public PassWrapper<TPUExtractHeadTailOutsideCompilation,
+                         OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+};
+
+void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
+  // Get runtime devices information from the closest parent module.
+  auto module = getOperation();
+  mlir::TF::RuntimeDevices devices;
+  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
+    return signalPassFailure();
+
+  OpBuilder builder(&getContext());
+  module.walk([&](tf_device::ClusterOp cluster) {
+    auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
+    IsolateHeadExtractedOpsToLaunchOp(&builder, cluster,
+                                      head_outside_compiled_ops);
+
+    // TODO(b/156030523): Update device attribute of newly created host launch
+    // op as well as enclosing Replicate op (if TPU computation is replicated)
+    // with host device names.
+
+    // TODO(b/155115766): Implement tail outside compiled op extraction.
   });
 }
 
 }  // anonymous namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 CreateTPUExtractHeadTailOutsideCompilationPass() {
   return std::make_unique<TPUExtractHeadTailOutsideCompilation>();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 98ff0de..f5e9da9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -92,7 +92,7 @@
 //
 // Would become following ops (unimportant attributes, types are omitted):
 //    %1 = "tf.Shape"(%0)
-//    %2:2 = "tf.MLIRCompileToTPU"(%1) {module = "<Serialized @tpu_func>"}
+//    %2:2 = "tf._TPUCompileMlir"(%1) {module = "<Serialized @tpu_func>"}
 //    "tf.TPUCompileSucceededAssert"(%2#0)
 //    %3 = "tf.TPUExecute"(%0, %2#1)
 //    %4 = "tf.SomeOp"(%3)
@@ -448,19 +448,20 @@
 // core, and all replica devices per core are grouped together.
 void AssignDevicesToReplicate(
     tf_device::ReplicateOp replicate,
-    llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+        tpu_devices,
     OpBuilder* builder) {
   if (!replicate) return;
 
-  const int num_replicas = execution_devices.size();
-  const int num_cores_per_replica = execution_devices.front().size();
+  const int num_replicas = tpu_devices.size();
+  const int num_cores_per_replica = tpu_devices.front().size();
 
   llvm::SmallVector<NamedAttribute, 8> device_attrs;
   for (int core = 0; core < num_cores_per_replica; ++core) {
     llvm::SmallVector<StringRef, 8> devices_by_core;
     devices_by_core.reserve(num_replicas);
     for (int replica = 0; replica < num_replicas; ++replica)
-      devices_by_core.push_back(execution_devices[replica][core]);
+      devices_by_core.push_back(tpu_devices[replica][core].device);
 
     device_attrs.push_back(
         builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
@@ -492,11 +493,12 @@
 // Creates a tf_device.parallel_execute op that wraps TPUExecute op to
 // represent execution of TPU program in multiple logical cores.
 LogicalResult BuildParallelExecuteOp(
-    llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+        tpu_devices,
     llvm::ArrayRef<xla::OpSharding> output_sharding_config,
     Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
     OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
-  const int num_cores_per_replica = execution_devices.front().size();
+  const int num_cores_per_replica = tpu_devices.front().size();
   // parallel_execute op returns concatenated list of return values of
   // all its regions.
   //
@@ -528,7 +530,7 @@
       num_cores_per_replica, cluster_func, builder, &input_list);
   if (failed(result)) return failure();
 
-  const bool replicated = execution_devices.size() != 1;
+  const bool replicated = tpu_devices.size() != 1;
   // For each logical core, create a region with TPUExecute op.
   assert(input_list.size() == num_cores_per_replica);
   for (int core = 0; core < num_cores_per_replica; ++core) {
@@ -553,7 +555,7 @@
     // op.
     std::string device = replicated
                              ? tensorflow::GetDeviceAliasForLogicalCore(core)
-                             : execution_devices.front()[core];
+                             : tpu_devices.front()[core].device;
 
     auto region_launch_op =
         WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device);
@@ -566,13 +568,14 @@
 }
 
 tf_device::LaunchOp AssignDevicesToReplicatedExecute(
-    llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+    llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+        tpu_devices,
     Operation* execute_op, OpBuilder* builder) {
-  const bool replicated = execution_devices.size() != 1;
+  const bool replicated = tpu_devices.size() != 1;
   // If computation is replicated, use aliased device. Otherwise there is only
   // one execution device and the device is assigned to the execute op.
   std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
-                                  : execution_devices.front().front();
+                                  : tpu_devices.front().front().device;
 
   return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
 }
@@ -687,6 +690,16 @@
   // Create compile op.
   auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
   builder->setInsertionPoint(cluster_func);
+
+  // Create the TPUCompileMlir and TPUCompileSucceededAssert ops outside of
+  // the parallel_execute region, if one exists.
+  if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func.getParentOp())) {
+    // Currently, outside compilation and model parallelism are not supported
+    // together.
+    assert(num_cores_per_replica == 1);
+    builder->setInsertionPoint(cluster_func.getParentOp());
+  }
+
   Operation* compile_op = BuildCompileOp(
       cluster_func, num_replicas, num_cores_per_replica,
       tpu_device_assignment.compilation_device,
@@ -704,7 +717,7 @@
   BuildTPUCompileSucceededAssertOp(
       compile_op, tpu_device_assignment.compilation_device, builder);
 
-  AssignDevicesToReplicate(replicate, tpu_device_assignment.execution_devices,
+  AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                            builder);
 
   llvm::SmallVector<xla::OpSharding, 4> output_shardings;
@@ -712,12 +725,13 @@
       num_cores_per_replica, cluster_func, &output_shardings);
   if (failed(result)) return failure();
 
+  builder->setInsertionPoint(cluster_func);
   if (num_cores_per_replica > 1) {
     // For model parallelism, tf_device.parallel_execute is used to express
     // concurrent device execution across multiple logical devices.
 
     tf_device::ParallelExecuteOp execute_op;
-    result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices,
+    result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices,
                                     output_shardings, compile_op, cluster_func,
                                     builder, &execute_op);
     if (failed(result)) return failure();
@@ -740,7 +754,7 @@
     if (failed(result)) return failure();
 
     tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
-        tpu_device_assignment.execution_devices, execute_op, builder);
+        tpu_device_assignment.tpu_devices, execute_op, builder);
     cluster_func.replaceAllUsesWith(launch_op);
   }
 
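
The device-selection rule above is the same in both the per-core and the
whole-computation paths; as a standalone sketch (the TPU_REPLICATED_CORE_<n>
alias format is an assumption based on GetDeviceAliasForLogicalCore):

    // Standalone sketch, not TF code: choose the device for a TPUExecute op.
    #include <string>
    #include <vector>

    struct TPUDeviceAndHost { std::string device, host; };

    std::string SelectExecuteDevice(
        const std::vector<std::vector<TPUDeviceAndHost>>& tpu_devices,
        int core) {
      const bool replicated = tpu_devices.size() != 1;
      // Replicated: emit a logical-core alias that tf_device.replicate later
      // resolves to a concrete device per replica. Single replica: bind the
      // concrete TPU device directly.
      return replicated ? "TPU_REPLICATED_CORE_" + std::to_string(core)
                        : tpu_devices.front()[core].device;
    }
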
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 2374687..e8ca691 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -293,6 +293,12 @@
   tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
 
+  // Run the shape inference pass to propagate shapes through tensor_cast
+  // operations from static to dynamic shapes. Such casts can be generated if
+  // shape inference was originally missing for a TF op but the corresponding
+  // HLO op had a static shape after lowering.
+  tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
+
   // Run LegalizeTFPass again because the previous legalization passes can
   // expose more graph pruning and canonicalization opportunities that are
   // necessary for the second LegalizeTFPass(allow_partial_conversion=false)
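
For context, the resulting pass ordering in this function reads as the sketch
below (pass names as used in this file; the surrounding pipeline is elided):

    tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
    tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
    // New: resolve static->dynamic tensor_cast ops left by the first,
    // partial legalization before the final strict LegalizeTFPass runs.
    tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
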
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
index fcfef56..b28f26b 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
@@ -31,12 +31,14 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/tstring.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
@@ -131,13 +133,21 @@
   case DTYPE:                      \
     return ConvertFlatTensor<CTYPE>(input_tensor, type);
 
-  // TODO(fengliuai): customize the conversions for more types.
+  // TODO(fengliuai): customize the conversions for quantized and string types.
   switch (input_dtype) {
     CONVERT_FLAT(DT_BOOL, bool)
     CONVERT_FLAT(DT_FLOAT, float)
     CONVERT_FLAT(DT_DOUBLE, double)
+    CONVERT_FLAT(DT_INT8, int8)
+    CONVERT_FLAT(DT_INT16, int16)
     CONVERT_FLAT(DT_INT32, int32)
     CONVERT_FLAT(DT_INT64, int64)
+    CONVERT_FLAT(DT_UINT8, uint8)
+    CONVERT_FLAT(DT_UINT16, uint16)
+    CONVERT_FLAT(DT_UINT32, uint32)
+    CONVERT_FLAT(DT_UINT64, uint64)
+    CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
+    CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)
 
     // BFLOAT16 is a special case that it needs to be cast to double type to
     // match its storage type.
@@ -207,12 +217,20 @@
 
 // Converts an MLIR dense string elements attribute and adds it to the
 // specified repeated string field.
-Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
-                                 TensorProto* output_tensor) {
-  for (const auto& val : attr.getRawStringData()) {
-    output_tensor->add_string_val(val.data(), val.size());
+void ConvertStringElementsAttr(
+    const DenseStringElementsAttr attr,
+    protobuf::RepeatedPtrField<std::string>* output) {
+  for (const auto& val : attr.getRawStringData())
+    output->Add({val.data(), val.size()});
+}
+
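+// Converts an MLIR elements attribute containing complex values and adds the
+// interleaved real and imaginary parts to the specified repeated field.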
+template <typename T>
+void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
+                                protobuf::RepeatedField<T>* output) {
+  for (const auto& val : attr.getValues<std::complex<T>>()) {
+    output->Add(val.real());
+    output->Add(val.imag());
   }
-  return Status::OK();
 }
 
 // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
@@ -226,139 +244,80 @@
   return InvalidArgument("Unexpected elements attribute type from MLIR.");
 }
 
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the double_val field updated.
-Status ConvertDoubleElementsAttr(const ElementsAttr attr,
-                                 TensorProto* output_tensor) {
-  if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
-    if (elts.isSplat()) {
-      output_tensor->add_double_val(elts.getSplatValue<double>());
-    } else {
-      for (auto value : elts.getValues<double>())
-        output_tensor->add_double_val(value);
-    }
-    return Status::OK();
-  }
-  return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the float_val field updated.
-Status ConvertFloatElementsAttr(const ElementsAttr attr,
-                                TensorProto* output_tensor) {
-  if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
-    if (elts.isSplat()) {
-      output_tensor->add_float_val(elts.getSplatValue<float>());
-    } else {
-      for (auto value : elts.getValues<float>())
-        output_tensor->add_float_val(value);
-    }
-    return Status::OK();
-  }
-  return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the half_val field updated.
-Status ConvertHalfElementsAttr(const ElementsAttr attr,
-                               TensorProto* output_tensor) {
-  if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
-    if (elts.isSplat()) {
-      output_tensor->add_half_val(
-          (*elts.begin()).bitcastToAPInt().getSExtValue());
-    } else {
-      for (const auto& value : elts.getFloatValues())
-        output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
-    }
-    return Status::OK();
-  }
-  return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the int_val field updated.
-Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
-                              TensorProto* output_tensor) {
-  if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
-    if (elts.isSplat()) {
-      output_tensor->add_int_val((*elts.begin()).getSExtValue());
-    } else {
-      for (const auto& val : elts)
-        output_tensor->add_int_val(val.getSExtValue());
-    }
-    return Status::OK();
-  }
-  return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
-                                   TensorProto* output_tensor) {
-  auto elts = attr.dyn_cast<DenseFPElementsAttr>();
-  if (!elts) {
-    return ConvertOpaqueElementsAttr(attr, output_tensor);
-  }
-
-  // Bfloat16 is internally represented as `double` in MLIR.
-  if (elts.isSplat()) {
-    double v = elts.getSplatValue<double>();
-    bfloat16 bf16_val = static_cast<bfloat16>(v);
-    output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
+// Converts an MLIR elements attribute and adds it to the specified repeated
+// field.
+template <typename T>
+void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
+                         protobuf::RepeatedField<T>* output) {
+  if (attr.isSplat()) {
+    output->Add(attr.getSplatValue<T>());
   } else {
-    for (auto v : elts.getValues<double>()) {
+    for (auto value : attr.getValues<T>()) output->Add(value);
+  }
+}
+
+// Converts an MLIR elements attribute containing half values and adds it to
+// the specified repeated field.
+void ConvertHalfElementsAttr(const DenseFPElementsAttr attr,
+                             protobuf::RepeatedField<int>* output) {
+  if (attr.isSplat()) {
+    output->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
+  } else {
+    for (const llvm::APFloat value : attr.getFloatValues())
+      output->Add(value.bitcastToAPInt().getSExtValue());
+  }
+}
+
+// Converts an MLIR elements attribute containing int values and adds it to
+// the specified repeated field.
+void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
+                            protobuf::RepeatedField<int>* output) {
+  if (attr.isSplat()) {
+    output->Add((*attr.begin()).getSExtValue());
+  } else {
+    for (const llvm::APInt val : attr) output->Add(val.getSExtValue());
+  }
+}
+
+void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
+                                 protobuf::RepeatedField<int>* output) {
+  // Bfloat16 is internally represented as `double` in MLIR.
+  if (attr.isSplat()) {
+    double v = attr.getSplatValue<double>();
+    bfloat16 bf16_val = static_cast<bfloat16>(v);
+    output->Add(absl::bit_cast<int16>(bf16_val));
+  } else {
+    for (auto v : attr.getValues<double>()) {
       bfloat16 bf16_val = static_cast<bfloat16>(v);
-      output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
+      output->Add(absl::bit_cast<int16>(bf16_val));
     }
   }
-
-  return Status::OK();
 }
 
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the int64_val field updated.
-Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
-                                TensorProto* output_tensor) {
-  if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
-    if (elts.isSplat()) {
-      output_tensor->add_int64_val((*elts.begin()).getSExtValue());
-    } else {
-      for (const auto& val : elts)
-        output_tensor->add_int64_val(val.getSExtValue());
-    }
-    return Status::OK();
-  }
-  return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with bool_val field updated.
-Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
-                               TensorProto* output_tensor) {
-  if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
-    for (const auto& val : elts) {
-      output_tensor->add_bool_val(val.getBoolValue());
-    }
-    return Status::OK();
-  }
-  return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-Status ConvertToTensorProto(const ElementsAttr attr,
-                            TensorProto* output_tensor) {
+Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
   auto type = attr.getType();
   auto shape = type.getShape();
   DataType output_dtype;
   TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
-  output_tensor->set_dtype(output_dtype);
-  ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
+  output->set_dtype(output_dtype);
+  ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
+
+  if (attr.isa<OpaqueElementsAttr>())
+    return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
+
+  auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
+  if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
 
   switch (output_dtype) {
     case DT_FLOAT:
-      return ConvertFloatElementsAttr(attr, output_tensor);
+      ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
+      break;
     case DT_HALF:
-      // Handles both DenseFPElementsAttr and OpaqueElementsAttr.
-      return ConvertHalfElementsAttr(attr, output_tensor);
+      ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
+                              output->mutable_half_val());
+      break;
     case DT_DOUBLE:
-      return ConvertDoubleElementsAttr(attr, output_tensor);
+      ConvertElementsAttr(dense_attr, output->mutable_double_val());
+      break;
     case DT_QUINT8:
     case DT_UINT8:
     case DT_INT8:
@@ -366,20 +325,40 @@
     case DT_UINT16:
     case DT_INT16:
     case DT_INT32:
-      return ConvertIntElementsAttr(attr, output_tensor);
+      ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
+                             output->mutable_int_val());
+      break;
+    case DT_UINT32:
+      ConvertElementsAttr(dense_attr, output->mutable_uint32_val());
+      break;
+    case DT_UINT64:
+      ConvertElementsAttr(dense_attr, output->mutable_uint64_val());
+      break;
     case DT_INT64:
-      return ConvertInt64ElementsAttr(attr, output_tensor);
+      ConvertElementsAttr(dense_attr, output->mutable_int64_val());
+      break;
     case DT_BOOL:
-      return ConvertBoolElementsAttr(attr, output_tensor);
+      ConvertElementsAttr(dense_attr, output->mutable_bool_val());
+      break;
     case DT_BFLOAT16:
-      return ConvertBfloat16ElementsAttr(attr, output_tensor);
+      ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
+                                  output->mutable_half_val());
+      break;
     case DT_STRING:
-      return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
-                                       output_tensor);
+      ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
+                                output->mutable_string_val());
+      break;
+    case DT_COMPLEX64:
+      ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
+      break;
+    case DT_COMPLEX128:
+      ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
+      break;
     default:
-      return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
-                                       output_tensor);
+      return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
+                                                DataTypeString(output_dtype)));
   }
+  return Status::OK();
 }
 
 Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
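
To make the new interleaved complex layout concrete, a small sketch with
made-up values (DT_COMPLEX64 uses the float-typed scomplex_val field;
DT_COMPLEX128 uses dcomplex_val analogously):

    #include <complex>
    #include "tensorflow/core/framework/tensor.pb.h"

    // {1+2i, 3+4i} flattens to scomplex_val = [1, 2, 3, 4].
    tensorflow::TensorProto proto;
    proto.set_dtype(tensorflow::DT_COMPLEX64);
    for (std::complex<float> v :
         {std::complex<float>(1, 2), std::complex<float>(3, 4)}) {
      proto.add_scomplex_val(v.real());
      proto.add_scomplex_val(v.imag());
    }
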
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
index d711c19..bf96e3d 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
 
 #include <cstring>
+#include <initializer_list>
 
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
@@ -99,48 +100,74 @@
   EXPECT_EQ(string_values[3], mlir::StringRef("four"));
 }
 
-TEST(ConvertTypeToTensorTypeTest, Convert16BitFloats) {
+class ConvertTensorTest : public ::testing::Test {
+ protected:
+  template <typename T>
+  void VerifyConversion(std::initializer_list<T> values, DataType dtype,
+                        mlir::Type expected_ty) {
+    mlir::Builder b(expected_ty.getContext());
+    Tensor tensor(dtype, TensorShape({static_cast<int64>(values.size())}));
+    tensor.flat<T>().setValues(values);
+
+    auto value_or = ConvertTensor(tensor, &b);
+    TF_ASSERT_OK(value_or.status());
+    auto attr = value_or.ValueOrDie();
+
+    EXPECT_EQ(attr.getType().getElementType(), expected_ty);
+
+    Tensor out;
+    TF_ASSERT_OK(ConvertToTensor(attr, &out));
+
+    test::ExpectTensorEqual<T>(tensor, out);
+  }
+};
+
+TEST_F(ConvertTensorTest, Simple) {
   RegisterDialects();
+
   mlir::MLIRContext context;
-  mlir::Builder b(&context);
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
+      {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
+  ASSERT_NO_FATAL_FAILURE(
+      VerifyConversion<bfloat16>({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16,
+                                 mlir::FloatType::getBF16(&context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<float>(
+      {1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<double>(
+      {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
 
-  {
-    // Create the sample tensor to convert.
-    Tensor tensor(DT_HALF, TensorShape({1}));
-    auto Tt = tensor.flat<Eigen::half>();
-    Tt.setValues({Eigen::half(1.0)});
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
+      {1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
+      {1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
+      {1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
+      {1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
 
-    auto value_or = ConvertTensor(tensor, &b);
-    TF_EXPECT_OK(value_or.status());
-    auto attr = value_or.ValueOrDie();
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
+      {1, 2}, DT_UINT8,
+      mlir::IntegerType::get(
+          8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
+      {1, 2}, DT_UINT16,
+      mlir::IntegerType::get(
+          16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
+      {1, 2}, DT_UINT32,
+      mlir::IntegerType::get(
+          32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
+      {1, 2}, DT_UINT64,
+      mlir::IntegerType::get(
+          64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
 
-    EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
-    EXPECT_TRUE(attr.getType().getElementType().isF16());
-
-    Tensor out;
-    TF_ASSERT_OK(ConvertToTensor(attr, &out));
-
-    test::ExpectTensorEqual<Eigen::half>(tensor, out);
-  }
-
-  {
-    // Create the sample tensor to convert.
-    Tensor tensor(DT_BFLOAT16, TensorShape({2}));
-    auto Tt = tensor.flat<bfloat16>();
-    Tt.setValues({bfloat16(1.0), bfloat16(-1.0)});
-
-    auto value_or = ConvertTensor(tensor, &b);
-    TF_EXPECT_OK(value_or.status());
-    auto attr = value_or.ValueOrDie();
-
-    EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
-    EXPECT_TRUE(attr.getType().getElementType().isBF16());
-
-    Tensor out;
-    TF_ASSERT_OK(ConvertToTensor(attr, &out));
-
-    test::ExpectTensorEqual<bfloat16>(tensor, out);
-  }
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
+      {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
+      mlir::ComplexType::get(mlir::FloatType::getF32(&context))));
+  ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<double>>(
+      {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128,
+      mlir::ComplexType::get(mlir::FloatType::getF64(&context))));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
index cc79525..4877cbc 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
@@ -59,6 +59,18 @@
 
 namespace tensorflow {
 namespace {
+// Returns the global set of TensorFlow op prefixes (lazily initialized).
+std::set<std::string>* GlobalOpPrefixes() {
+  static std::set<std::string>* global_op_prefixes = [] {
+    std::set<std::string>* result = new std::set<std::string>;
+    result->insert("tf.");
+    result->insert("_tf.");
+    result->insert("tf_executor.");
+    return result;
+  }();
+  return global_op_prefixes;
+}
+
 // Converts a location to the debug information for the node def.
 Status ConvertLocation(mlir::Location inst_loc,
                        NodeDef::ExperimentalDebugInfo* debug_info) {
@@ -268,8 +280,10 @@
   // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
   // don't need to consider ".source"/".Source" because the nodes with this
   // suffix are skipped by the caller and will not be added to the graph.
-  if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") &&
-      !op_name.consume_front("tf_executor.")) {
+  auto prefixes = GlobalOpPrefixes();
+  if (std::none_of(prefixes->begin(), prefixes->end(),
+                   [&](const std::string& prefix) {
+                     return op_name.consume_front(prefix);
+                   })) {
     return errors::FailedPrecondition("op node '", op_name.str(),
                                       "' was not a TF op!");
   }
@@ -506,4 +520,9 @@
          inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
 }
 
+Status AddTensorFlowOpPrefix(std::string prefix) {
+  GlobalOpPrefixes()->insert(prefix);
+  return Status::OK();
+}
+
 }  // namespace tensorflow
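
The new hook is what lets out-of-tree dialects reuse this exporter; a usage
sketch with a hypothetical "mydialect." prefix (json_translate.cc below
registers "tfjs." this way):

    TF_RETURN_IF_ERROR(tensorflow::AddTensorFlowOpPrefix("mydialect."));
    // GetTensorFlowOpName("mydialect.Conv2D") now strips the prefix and
    // yields "Conv2D" instead of failing with "was not a TF op!".
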
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h
index 32ed528..58fe39f 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h
@@ -34,10 +34,17 @@
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
+namespace mlir {
+class ShapedType;
+}  // namespace mlir
+
 namespace tensorflow {
 
 using stream_executor::port::StatusOr;
 
+// Adds a custom op prefix recognized as a TensorFlow dialect during export.
+Status AddTensorFlowOpPrefix(std::string);
+
 // Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control
 // dialect back into a TensorFlow valid op name.
 StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
index ddbcc91..06c10c2 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
@@ -164,12 +164,19 @@
   return DeviceNameUtils::ParsedNameToString(system_device);
 }
 
+// Finds the host CPU device for a given TPU device.
+std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) {
+  tpu_device.type = DEVICE_CPU;
+  tpu_device.id = 0;
+  return DeviceNameUtils::ParsedNameToString(tpu_device);
+}
+
 // Determines execution devices when topology and device assignment are not
 // defined. This is a special case where a single core computation is replicated
 // to every core in the mesh. Each TPU device, paired with its host CPU, is
 // added as its own single-core replica. `num_replicas` must be 1 or the total
 // number of TPU devices available, and `num_cores_per_replica` must be 1.
-StatusOr<ExecutionDevices> GetFullMeshTPUExecutionDeviceAssignment(
+StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
     int num_replicas, int num_cores_per_replica,
     llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices) {
   const int num_tasks = tpu_devices.size();
@@ -185,17 +192,18 @@
         "'num_cores_per_replica' must be equal to 1, got ",
         num_cores_per_replica);
 
-  ExecutionDevices execution_devices;
-  execution_devices.reserve(num_replicas);
+  TPUDevicesAndHosts devices_and_hosts;
+  devices_and_hosts.reserve(num_replicas);
   for (int i = 0; i < num_replicas; ++i) {
     const int task = i / num_tpus_per_task;
     const int device = i % num_tpus_per_task;
-    execution_devices.push_back(
-        {tensorflow::DeviceNameUtils::ParsedNameToString(
-            tpu_devices[task][device])});
+    const auto& tpu_device = tpu_devices[task][device];
+    devices_and_hosts.push_back({TPUDeviceAndHost(
+        /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device),
+        /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))});
   }
 
-  return execution_devices;
+  return devices_and_hosts;
 }
 
 // Helper struct for keeping track of task and device for an associated TPU
@@ -326,7 +334,7 @@
 //  - number of device coordinates (in tuple 3) match number 'num_replicas' *
 //    'num_cores_per_replica'
 //  - a TPU device associated with each device coordinate
-StatusOr<std::pair<ExecutionDevices, xla::DeviceAssignmentProto>>
+StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
 GetGeneralTPUExecutionDeviceAssignment(
     int num_replicas, int num_cores_per_replica,
     llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices,
@@ -361,9 +369,9 @@
   std::vector<bool> used_device_ids(
       location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1),
       false);
-  ExecutionDevices execution_devices(
-      num_replicas,
-      llvm::SmallVector<std::string, 8>(num_cores_per_replica, ""));
+  TPUDevicesAndHosts devices_and_hosts(
+      num_replicas, llvm::SmallVector<TPUDeviceAndHost, 8>(
+                        num_cores_per_replica, TPUDeviceAndHost()));
   xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica);
   int pos = 0;
   for (int replica = 0; replica < num_replicas; ++replica) {
@@ -393,16 +401,18 @@
 
       used_device_ids[device_id] = true;
       device_assignment(replica, logical_core) = device_id;
-      execution_devices[replica][logical_core] =
-          DeviceNameUtils::ParsedNameToString(tpu_devices[task][device]);
+      auto& device_and_host = devices_and_hosts[replica][logical_core];
+      const auto& tpu_device = tpu_devices[task][device];
+      device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device);
+      device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device);
     }
   }
 
   xla::DeviceAssignmentProto device_assignment_proto;
   TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto));
 
-  return std::pair<ExecutionDevices, xla::DeviceAssignmentProto>(
-      std::move(execution_devices), std::move(device_assignment_proto));
+  return std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>(
+      std::move(devices_and_hosts), std::move(device_assignment_proto));
 }
 
 }  // anonymous namespace
@@ -447,27 +457,4 @@
   return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str();
 }
 
-StatusOr<std::string> GetCPUHostForTPUDevice(llvm::StringRef tpu_device) {
-  Device device;
-  if (!DeviceNameUtils::ParseFullName(tpu_device.str(), &device))
-    return errors::InvalidArgument("'", tpu_device.str(),
-                                   "' is not a valid device");
-
-  device.type = DEVICE_CPU;
-  device.id = 0;
-  return DeviceNameUtils::ParsedNameToString(device);
-}
-
-StatusOr<llvm::SmallVector<std::string, 8>> GetCPUHostsForTPUDevices(
-    llvm::ArrayRef<std::string> tpu_devices) {
-  llvm::SmallVector<std::string, 8> cpu_devices;
-  cpu_devices.reserve(tpu_devices.size());
-  for (const auto& tpu_device : tpu_devices) {
-    TF_ASSIGN_OR_RETURN(cpu_devices.emplace_back(),
-                        GetCPUHostForTPUDevice(tpu_device));
-  }
-
-  return cpu_devices;
-}
-
 }  // namespace tensorflow
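
The removed GetCPUHostForTPUDevice helpers are subsumed by
GetCPUHostDeviceForTPUDevice above: only the device type and id change. A
sketch of the mapping (Device is the ParsedName alias used in this file):

    Device device;
    DeviceNameUtils::ParseFullName(
        "/job:worker/replica:0/task:1/device:TPU:3", &device);
    device.type = DEVICE_CPU;  // "CPU"
    device.id = 0;
    // Yields "/job:worker/replica:0/task:1/device:CPU:0".
    std::string host = DeviceNameUtils::ParsedNameToString(device);
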
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
index 47ce7f1..5fdb6b8 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
@@ -30,29 +30,40 @@
 namespace tensorflow {
 using stream_executor::port::StatusOr;
 
-// TPU devices to be used for execution (e.g. devices for TPUExecute ops). They
-// are ordered by `num_replicas` followed by `num_cores_per_replica`.
-using ExecutionDevices =
-    llvm::SmallVector<llvm::SmallVector<std::string, 8>, 8>;
+// A TPU device for execution alongside its associated host CPU device.
+struct TPUDeviceAndHost {
+  TPUDeviceAndHost() {}
+  TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host)
+      : device(device), host(host) {}
 
-// TPU compilation device, execution devices, and optionally execution device
-// IDs. Execution device IDs are populated if `topology` and `device_assignment`
-// are provided.
+  std::string device;
+  std::string host;
+};
+
+// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and
+// their associated host CPU devices (for outside compilation). They are ordered
+// by `num_replicas` followed by `num_cores_per_replica`.
+using TPUDevicesAndHosts =
+    llvm::SmallVector<llvm::SmallVector<TPUDeviceAndHost, 8>, 8>;
+
+// TPU compilation device, execution and associated host devices, and optionally
+// execution device IDs. Execution device IDs are populated if `topology` and
+// `device_assignment` are provided.
 struct TPUDeviceAssignment {
   TPUDeviceAssignment(llvm::StringRef compilation_device,
-                      ExecutionDevices&& execution_devices)
+                      TPUDevicesAndHosts&& tpu_devices)
       : compilation_device(compilation_device),
-        execution_devices(std::move(execution_devices)) {}
+        tpu_devices(std::move(tpu_devices)) {}
 
   TPUDeviceAssignment(llvm::StringRef compilation_device,
-                      ExecutionDevices&& execution_devices,
+                      TPUDevicesAndHosts&& tpu_devices,
                       xla::DeviceAssignmentProto&& xla_device_assignment)
       : compilation_device(compilation_device),
-        execution_devices(std::move(execution_devices)),
+        tpu_devices(std::move(tpu_devices)),
         xla_device_assignment(std::move(xla_device_assignment)) {}
 
   std::string compilation_device;
-  ExecutionDevices execution_devices;
+  TPUDevicesAndHosts tpu_devices;
   llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
 };
 
@@ -216,17 +227,6 @@
 // logical core.
 std::string GetDeviceAliasForLogicalCore(int core_index);
 
-// Finds associated CPU host device for given TPU device. This assumes a
-// matching CPU host device exists based on TPU device name. An error will be
-// returned if the TPU device name is invalid.
-StatusOr<std::string> GetCPUHostForTPUDevice(llvm::StringRef tpu_device);
-
-// Finds associated CPU host devices for given TPU devices. This assumes a
-// matching CPU host device exist based on each TPU device name. An error will
-// be returned if a TPU device name is invalid.
-StatusOr<llvm::SmallVector<std::string, 8>> GetCPUHostsForTPUDevices(
-    llvm::ArrayRef<std::string> tpu_devices);
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_
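
A consumption sketch for the new structure (device strings shortened; the
outer index is the replica, the inner index the logical core):

    tensorflow::TPUDevicesAndHosts devices;
    devices.push_back({tensorflow::TPUDeviceAndHost(
        /*device=*/"/job:w/replica:0/task:0/device:TPU:0",
        /*host=*/"/job:w/replica:0/task:0/device:CPU:0")});
    for (const auto& replica : devices)
      for (const auto& core : replica)
        LOG(INFO) << core.device << " outside-compiles on " << core.host;
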
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
index 57e123a..7ac5635a 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
@@ -323,30 +323,46 @@
 
   TF_ASSERT_OK(status_or.status());
 
-  auto& tpu_device_assignment = status_or.ValueOrDie();
+  const auto& tpu_device_assignment = status_or.ValueOrDie();
   EXPECT_EQ(tpu_device_assignment.compilation_device,
             "/job:worker/replica:0/task:0/device:CPU:0");
-  auto& execution_devices = tpu_device_assignment.execution_devices;
-  ASSERT_EQ(execution_devices.size(), 8);
-  for (const auto& replica_execution_device : execution_devices)
-    ASSERT_EQ(replica_execution_device.size(), 1);
+  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
+  ASSERT_EQ(tpu_devices.size(), 8);
+  for (const auto& replica_tpu_devices : tpu_devices)
+    ASSERT_EQ(replica_tpu_devices.size(), 1);
 
-  EXPECT_EQ(execution_devices[0][0],
+  EXPECT_EQ(tpu_devices[0][0].device,
             "/job:worker/replica:0/task:0/device:TPU:0");
-  EXPECT_EQ(execution_devices[1][0],
+  EXPECT_EQ(tpu_devices[0][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][0].device,
             "/job:worker/replica:0/task:0/device:TPU:1");
-  EXPECT_EQ(execution_devices[2][0],
+  EXPECT_EQ(tpu_devices[1][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[2][0].device,
             "/job:worker/replica:0/task:0/device:TPU:2");
-  EXPECT_EQ(execution_devices[3][0],
+  EXPECT_EQ(tpu_devices[2][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[3][0].device,
             "/job:worker/replica:0/task:0/device:TPU:3");
-  EXPECT_EQ(execution_devices[4][0],
+  EXPECT_EQ(tpu_devices[3][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[4][0].device,
             "/job:worker/replica:0/task:1/device:TPU:0");
-  EXPECT_EQ(execution_devices[5][0],
+  EXPECT_EQ(tpu_devices[4][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[5][0].device,
             "/job:worker/replica:0/task:1/device:TPU:1");
-  EXPECT_EQ(execution_devices[6][0],
+  EXPECT_EQ(tpu_devices[5][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[6][0].device,
             "/job:worker/replica:0/task:1/device:TPU:2");
-  EXPECT_EQ(execution_devices[7][0],
+  EXPECT_EQ(tpu_devices[6][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[7][0].device,
             "/job:worker/replica:0/task:1/device:TPU:3");
+  EXPECT_EQ(tpu_devices[7][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
 
   EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue());
 }
@@ -410,30 +426,46 @@
 
   TF_ASSERT_OK(status_or.status());
 
-  auto& tpu_device_assignment = status_or.ValueOrDie();
+  const auto& tpu_device_assignment = status_or.ValueOrDie();
   EXPECT_EQ(tpu_device_assignment.compilation_device,
             "/job:worker/replica:0/task:0/device:CPU:0");
-  auto& execution_devices = tpu_device_assignment.execution_devices;
-  ASSERT_EQ(execution_devices.size(), 4);
-  for (const auto& replica_execution_device : execution_devices)
-    ASSERT_EQ(replica_execution_device.size(), 2);
+  const auto& tpu_devices = tpu_device_assignment.tpu_devices;
+  ASSERT_EQ(tpu_devices.size(), 4);
+  for (const auto& replica_tpu_devices : tpu_devices)
+    ASSERT_EQ(replica_tpu_devices.size(), 2);
 
-  EXPECT_EQ(execution_devices[0][0],
+  EXPECT_EQ(tpu_devices[0][0].device,
             "/job:worker/replica:0/task:0/device:TPU:0");
-  EXPECT_EQ(execution_devices[0][1],
+  EXPECT_EQ(tpu_devices[0][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[0][1].device,
             "/job:worker/replica:0/task:1/device:TPU:3");
-  EXPECT_EQ(execution_devices[1][0],
+  EXPECT_EQ(tpu_devices[0][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][0].device,
             "/job:worker/replica:0/task:0/device:TPU:1");
-  EXPECT_EQ(execution_devices[1][1],
+  EXPECT_EQ(tpu_devices[1][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][1].device,
             "/job:worker/replica:0/task:1/device:TPU:2");
-  EXPECT_EQ(execution_devices[2][0],
+  EXPECT_EQ(tpu_devices[1][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[2][0].device,
             "/job:worker/replica:0/task:0/device:TPU:3");
-  EXPECT_EQ(execution_devices[2][1],
+  EXPECT_EQ(tpu_devices[2][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[2][1].device,
             "/job:worker/replica:0/task:1/device:TPU:0");
-  EXPECT_EQ(execution_devices[3][0],
+  EXPECT_EQ(tpu_devices[2][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[3][0].device,
             "/job:worker/replica:0/task:0/device:TPU:2");
-  EXPECT_EQ(execution_devices[3][1],
+  EXPECT_EQ(tpu_devices[3][0].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[3][1].device,
             "/job:worker/replica:0/task:1/device:TPU:1");
+  EXPECT_EQ(tpu_devices[3][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
 
   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
   ASSERT_TRUE(xla_device_assignment.hasValue());
@@ -511,23 +543,35 @@
   EXPECT_EQ(tpu_device_assignment.compilation_device,
             "/job:worker/replica:0/task:0/device:CPU:0");
 
-  auto& execution_devices = tpu_device_assignment.execution_devices;
-  ASSERT_EQ(execution_devices.size(), 2);
-  for (const auto& replica_execution_device : execution_devices)
-    ASSERT_EQ(replica_execution_device.size(), 3);
+  auto& tpu_devices = tpu_device_assignment.tpu_devices;
+  ASSERT_EQ(tpu_devices.size(), 2);
+  for (const auto& replica_tpu_devices : tpu_devices)
+    ASSERT_EQ(replica_tpu_devices.size(), 3);
 
-  EXPECT_EQ(execution_devices[0][0],
+  EXPECT_EQ(tpu_devices[0][0].device,
             "/job:worker/replica:0/task:1/device:TPU:1");
-  EXPECT_EQ(execution_devices[0][1],
+  EXPECT_EQ(tpu_devices[0][0].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[0][1].device,
             "/job:worker/replica:0/task:1/device:TPU:0");
-  EXPECT_EQ(execution_devices[0][2],
+  EXPECT_EQ(tpu_devices[0][1].host,
+            "/job:worker/replica:0/task:1/device:CPU:0");
+  EXPECT_EQ(tpu_devices[0][2].device,
             "/job:worker/replica:0/task:2/device:TPU:0");
-  EXPECT_EQ(execution_devices[1][0],
+  EXPECT_EQ(tpu_devices[0][2].host,
+            "/job:worker/replica:0/task:2/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][0].device,
             "/job:worker/replica:0/task:2/device:TPU:1");
-  EXPECT_EQ(execution_devices[1][1],
+  EXPECT_EQ(tpu_devices[1][0].host,
+            "/job:worker/replica:0/task:2/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][1].device,
             "/job:worker/replica:0/task:0/device:TPU:0");
-  EXPECT_EQ(execution_devices[1][2],
+  EXPECT_EQ(tpu_devices[1][1].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
+  EXPECT_EQ(tpu_devices[1][2].device,
             "/job:worker/replica:0/task:0/device:TPU:1");
+  EXPECT_EQ(tpu_devices[1][2].host,
+            "/job:worker/replica:0/task:0/device:CPU:0");
 
   auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
   ASSERT_TRUE(xla_device_assignment.hasValue());
@@ -552,44 +596,5 @@
   EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
 }
 
-struct ParameterizedCPUHostForTPUDeviceTest
-    : ::testing::TestWithParam<std::tuple<std::string, std::string>> {};
-
-TEST_P(ParameterizedCPUHostForTPUDeviceTest, CPUHostForTPUDevice) {
-  auto status_or_device = GetCPUHostForTPUDevice(std::get<0>(GetParam()));
-  TF_ASSERT_OK(status_or_device.status());
-  EXPECT_EQ(status_or_device.ValueOrDie(), std::get<1>(GetParam()));
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    CPUHostForTPUDevice, ParameterizedCPUHostForTPUDeviceTest,
-    ::testing::Values(
-        std::make_tuple("/job:worker/replica:0/task:0/device:TPU:0",
-                        "/job:worker/replica:0/task:0/device:CPU:0"),
-        std::make_tuple("/job:worker/replica:0/task:1/device:TPU:1",
-                        "/job:worker/replica:0/task:1/device:CPU:0")));
-
-TEST(TPURewriteDeviceUtilTest, CPUHostForTPUDeviceInvalidDevice) {
-  auto status_or_device = GetCPUHostForTPUDevice("bad_device");
-  ASSERT_FALSE(status_or_device.ok());
-}
-
-TEST(TPURewriteDeviceUtilTest, CPUHostsForTPUDevices) {
-  auto status_or_devices =
-      GetCPUHostsForTPUDevices({"/job:worker/replica:0/task:0/device:TPU:0",
-                                "/job:worker/replica:0/task:1/device:TPU:1"});
-  TF_ASSERT_OK(status_or_devices.status());
-  const auto& devices = status_or_devices.ValueOrDie();
-  ASSERT_EQ(devices.size(), 2);
-  EXPECT_EQ(devices[0], "/job:worker/replica:0/task:0/device:CPU:0");
-  EXPECT_EQ(devices[1], "/job:worker/replica:0/task:1/device:CPU:0");
-}
-
-TEST(TPURewriteDeviceUtilTest, CPUHostsForTPUDevicesInvalidDevice) {
-  auto status_or_devices = GetCPUHostsForTPUDevices(
-      {"/job:worker/replica:0/task:0/device:TPU:0", "bad_device"});
-  ASSERT_FALSE(status_or_devices.ok());
-}
-
 }  // anonymous namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD
index 9b731d2..806a77e 100644
--- a/tensorflow/compiler/mlir/tfjs/BUILD
+++ b/tensorflow/compiler/mlir/tfjs/BUILD
@@ -1,4 +1,5 @@
 load("//third_party/mlir:tblgen.bzl", "gentbl")
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -131,10 +132,106 @@
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
-        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
-        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Transforms",
     ],
 )
+
+cc_library(
+    name = "json_translate_lib",
+    srcs = [
+        "translate/json_translate.cc",
+    ],
+    hdrs = [
+        "translate/json_translate.h",
+    ],
+    deps = [
+        ":tensorflow_js",
+        ":tensorflow_js_dialect_registration",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+        "//tensorflow/compiler/mlir/tensorflow:export_utils",
+        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Translation",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "tf_to_tfjs_json",
+    srcs = ["translate/tf_to_tfjs_json.cc"],
+    hdrs = [
+        "translate/tf_to_tfjs_json.h",
+    ],
+    deps = [
+        ":json_translate_lib",
+        ":tfjs_optimize",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
+        "//tensorflow/compiler/mlir/tensorflow:error_util",
+        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
+        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
+        "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
+        "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+    alwayslink = 1,
+)
+
+tf_cc_binary(
+    name = "json_translate",
+    deps = [
+        ":json_translate_lib",
+        "@llvm-project//mlir:MlirTranslateMain",
+    ],
+)
+
+filegroup(
+    name = "tf_tfjs_translate_main",
+    srcs = [
+        "translate/tf_tfjs_translate.cc",
+    ],
+)
+
+tf_cc_binary(
+    name = "tf_tfjs_translate",
+    srcs = [":tf_tfjs_translate_main"],
+    deps = [
+        ":json_translate_lib",
+        ":tensorflow_js_passes",
+        ":tf_to_tfjs_json",
+        ":tfjs_optimize",
+        "//tensorflow/compiler/mlir:init_mlir",
+        "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
index 318895d..545183a 100644
--- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
+++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
@@ -28,6 +28,7 @@
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffects.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
+
 namespace mlir {
 namespace tfjs {
 
diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD
new file mode 100644
index 0000000..5c8d37d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD
@@ -0,0 +1,23 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+licenses(["notice"])
+
+glob_lit_tests(
+    data = [
+        ":test_utilities",
+    ],
+    driver = "@llvm-project//mlir:run_lit.sh",
+    test_file_exts = [
+        "pbtxt",
+    ],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+    name = "test_utilities",
+    testonly = True,
+    data = [
+        "//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt
new file mode 100644
index 0000000..f6a324f
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt
@@ -0,0 +1,78 @@
+# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure
+# Add two tensor<10xi32> inputs and return the sum multiplied by itself.
+
+node {
+  name: "Add"
+  op: "Add"
+  input: "input0"
+  input: "input1"
+  attr {
+    key: "T"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+node {
+  name: "input0"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+node {
+  name: "input1"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+node {
+  name: "Mul"
+  op: "Mul"
+  input: "Add"
+  input: "Add"
+  attr {
+    key: "T"
+    value {
+      type: DT_INT32
+    }
+  }
+}
+versions {
+  producer: 27
+}
+
+# CHECK: "name": "input0"
+# CHECK-NEXT: "op": "Placeholder"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "input1",
+# CHECK-NEXT: "op": "Placeholder"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Add"
+# CHECK-NEXT: "op": "AddV2"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "input0"
+# CHECK-NEXT: "input1"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Mul1"
+# CHECK-NEXT: "op": "Mul"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "Add"
+# CHECK-NEXT: "Add"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Mul"
+# CHECK-NEXT: "op": "_Retval"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "Mul1"
+# CHECK: "type": "DT_INT32"
+# CHECK: "library"
+# CHECK: "versions"
+# CHECK: "producer": 27
+
diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt
new file mode 100644
index 0000000..810db71
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt
@@ -0,0 +1,175 @@
+# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure
+# Compute a PReLU activation, Relu(x) + alpha * Relu(-x), on a float input.
+
+node {
+  name: "input0"
+  op: "Placeholder"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "shape"
+    value {
+      shape {
+        dim {
+          size: 10
+        }
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "alpha"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+        }
+        float_val: 0.5
+      }
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Relu"
+  op: "Relu"
+  input: "input0"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Neg"
+  op: "Neg"
+  input: "input0"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Relu1"
+  op: "Relu"
+  input: "Neg"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Mul"
+  op: "Mul"
+  input: "alpha"
+  input: "Relu1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "Add"
+  op: "Add"
+  input: "Relu"
+  input: "Mul"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  experimental_debug_info {
+  }
+}
+node {
+  name: "main"
+  op: "_Retval"
+  input: "Add"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "index"
+    value {
+      i: 0
+    }
+  }
+}
+library {
+}
+versions {
+  producer: 344
+}
+
+# CHECK: "node":
+# CHECK: "name": "input0",
+# CHECK-NEXT: "op": "Placeholder",
+# CHECK-NEXT: "attr":
+# CHECK: "type": "DT_FLOAT"
+# CHECK: "name": "Add.Relu.Neg.Relu1.Mul",
+# CHECK-NEXT: "op": "Const",
+# CHECK-NEXT: "attr":
+# CHECK: "value":
+# CHECK: "tensor":
+# CHECK: "dtype": "DT_FLOAT",
+# CHECK: "tensorShape": {},
+# CHECK: "floatVal":
+# CHECK: -0.5
+# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1",
+# CHECK-NEXT: "op": "Prelu",
+# CHECK-NEXT: "input":
+# CHECK: "input0",
+# CHECK: "Add.Relu.Neg.Relu1.Mul"
+# CHECK: "attr":
+# CHECK: "_output_shapes":
+# CHECK: "list":
+# CHECK: "shape":
+# CHECK: "dim":
+# CHECK: "size": "10"
+# CHECK: "experimentalDebugInfo": {}
+# CHECK: "name": "Add",
+# CHECK-NEXT: "op": "_Retval",
+# CHECK-NEXT: "input":
+# CHECK: "Add.Relu.Neg.Relu1.Mul1"
+# CHECK: "attr":
+# CHECK: "T":
+# CHECK: "type": "DT_FLOAT"
+# CHECK: "library": {},
+# CHECK: "versions":
+# CHECK: "producer": 344
+
diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc
index 631bb1a..a445937 100644
--- a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc
+++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
 
@@ -47,6 +46,11 @@
   // Canonicalize, CSE etc.
   pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   pm->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
+
+  // Raise to the executor dialect in order to use the GraphDef converter.
+  pm->addNestedPass<mlir::FuncOp>(
+      mlir::CreateFunctionalToExecutorDialectConversionPass());
+  pm->addNestedPass<mlir::FuncOp>(mlir::CreateBreakUpIslandsPass());
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc
new file mode 100644
index 0000000..7f4b8ff
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Translation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
+
+using mlir::ModuleOp;
+using mlir::TranslateFromMLIRRegistration;
+using std::string;
+using tensorflow::Status;
+using xla::StatusOr;
+
+// Translates the given MLIR module in the TFJS dialect to TFJS JSON
+// format. Returns true on success.
+bool tfjs::MlirToJSONTranslateFunction(ModuleOp module,
+                                       std::string* serialized_json) {
+  string json_output;
+  // Allow TF to treat TFJS ops as TF ops.
+  if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) {
+    LOG(ERROR) << "Failed to add tfjs op prefix.";
+    return false;
+  }
+  tensorflow::GraphExportConfig confs;
+  confs.export_shapes = true;
+  confs.export_library = true;
+  tensorflow::FunctionLibraryDefinition flib_def(
+      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
+  absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;
+  auto graph = absl::make_unique<tensorflow::Graph>(flib_def);
+  auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def,
+                                               &control_ret_nodes);
+  if (!status.ok()) {
+    LOG(ERROR) << "Graph export failed: " << status;
+    return false;
+  }
+  auto graphdef = absl::make_unique<tensorflow::GraphDef>();
+  graph->ToGraphDef(graphdef.get());
+
+  // Replace the _Arg nodes of the main function with Placeholder ops.
+  auto nodes = graphdef->mutable_node();
+  for (const auto& node : llvm::enumerate(*nodes)) {
+    if (node.value().op() == "_Arg") {
+      nodes->Mutable(node.index())->set_op("Placeholder");
+    }
+  }
+
+  tensorflow::protobuf::util::JsonPrintOptions json_options;
+  json_options.add_whitespace = true;
+  auto json_status = tensorflow::protobuf::util::MessageToJsonString(
+      *graphdef, &json_output, json_options);
+  if (!json_status.ok()) {
+    LOG(ERROR) << "Proto2Json failed: " << json_status;
+    return false;
+  }
+  *serialized_json = std::move(json_output);
+  return true;
+}
+
+static mlir::LogicalResult MlirToJSONFileTranslateFunction(
+    ModuleOp module, llvm::raw_ostream& output) {
+  std::string serialized_json;
+  if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json))
+    return mlir::failure();
+
+  output << serialized_json;
+  return mlir::success();
+}
+
+static TranslateFromMLIRRegistration MLIRToJSONFileTranslate(
+    "mlir-to-tfjs-json", MlirToJSONFileTranslateFunction);
diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.h b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h
new file mode 100644
index 0000000..0a931f7
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
+#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
+
+#include <string>
+
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tfjs {
+
+// Translates the given MLIR `module` into a JSON string. Returns true on
+// success, false otherwise.
+bool MlirToJSONTranslateFunction(mlir::ModuleOp module,
+                                 std::string* serialized_json);
+}  // namespace tfjs
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc
new file mode 100644
index 0000000..e735a3c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc
@@ -0,0 +1,173 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include <string>
+
+#include "absl/strings/str_split.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/FileUtilities.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/init_mlir.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
+#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
+#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+using llvm::cl::opt;
+using mlir::MLIRContext;
+using stream_executor::port::StatusOr;
+
+// NOLINTNEXTLINE
+opt<std::string> input_file_name(llvm::cl::Positional,
+                                 llvm::cl::desc("<input file>"),
+                                 llvm::cl::init("-"));
+
+// NOLINTNEXTLINE
+opt<bool> import_saved_model_object_graph(
+    "savedmodel-objectgraph-to-mlir",
+    llvm::cl::desc("Import a saved model to its MLIR representation"),
+    llvm::cl::value_desc("dir"));
+
+// NOLINTNEXTLINE
+opt<bool> import_saved_model_signature_defs(
+    "savedmodel-signaturedefs-to-mlir",
+    llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
+    llvm::cl::value_desc("dir"));
+
+// NOLINTNEXTLINE
+opt<std::string> saved_model_tags(
+    "tf-savedmodel-tags",
+    llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
+                   "separated by ','"),
+    llvm::cl::init("serve"));
+
+// NOLINTNEXTLINE
+opt<std::string> saved_model_exported_names(
+    "tf-savedmodel-exported-names",
+    llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
+                   "(the default) means export all."),
+    llvm::cl::init(""));
+
+// NOLINTNEXTLINE
+opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
+                                  llvm::cl::value_desc("filename"),
+                                  llvm::cl::init("-"));
+// NOLINTNEXTLINE
+opt<bool> input_mlir(
+    "input-mlir",
+    llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of "
+                   "GraphDef format"),
+    llvm::cl::init(false), llvm::cl::Hidden);
+// NOLINTNEXTLINE
+opt<bool> output_mlir(
+    "output-mlir",
+    llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"),
+    llvm::cl::init(false));
+
+// The following flag allows injecting opdefs, in addition to those already
+// part of the global TF registry, before the graph is imported. The primary
+// goal is to support custom ops. This is not intended as a general solution
+// for custom ops going forward, but mainly supports older models such as
+// mobilenet_ssd. More appropriate mechanisms, such as op hints or using
+// functions to represent composable ops, as in
+// https://github.com/tensorflow/community/pull/113, should be encouraged
+// going forward.
+// NOLINTNEXTLINE
+llvm::cl::list<std::string> custom_opdefs(
+    "tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing "
+                                       "graphdef"));
+
+// Debugging flag to print function mapping in the JSON.
+// NOLINTNEXTLINE
+static opt<bool> print_function_result_mapping(
+    "print-function-result-mapping",
+    llvm::cl::desc(
+        "Print the mapping of function result to json output buffer"),
+    llvm::cl::init(false));
+
+enum TranslationStatus { kTrSuccess, kTrFailure };
+
+static int PrintFunctionResultMapping(const std::string& result) {
+  std::cout << result << std::endl;
+  return kTrSuccess;
+}
+
+int main(int argc, char** argv) {
+  tensorflow::InitMlir y(&argc, &argv);
+
+  llvm::cl::ParseCommandLineOptions(argc, argv,
+                                    "TF GraphDef to TFJS JSON converter\n");
+
+  MLIRContext context;
+  llvm::SourceMgr source_mgr;
+  mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
+
+  StatusOr<mlir::OwningModuleRef> module;
+
+  if (import_saved_model_object_graph || import_saved_model_signature_defs) {
+    if (input_mlir) {
+      module = tensorflow::errors::InvalidArgument(
+          "Importing saved model should not have input_mlir set");
+    } else {
+      module = tensorflow::ImportSavedModel(
+          import_saved_model_object_graph, import_saved_model_signature_defs,
+          custom_opdefs, input_file_name, saved_model_tags,
+          saved_model_exported_names, &context);
+    }
+  } else {
+    module = tensorflow::LoadFromGraphdefOrMlirSource(
+        input_file_name, input_mlir, custom_opdefs, debug_info_file,
+        input_arrays, input_dtypes, input_shapes, output_arrays,
+        /*prune_unused_nodes=*/true, &source_mgr, &context);
+  }
+
+  // If errors occur, the library call in the above already logged the error
+  // message. So we can just return here.
+  if (!module.ok()) return kTrFailure;
+
+  mlir::PassManager pm(&context);
+
+  tensorflow::AddTFToTFJSConversionPasses(&pm);
+
+  std::string result;
+  auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(),
+                                                   output_mlir, &result, &pm);
+  if (!status.ok()) return kTrFailure;
+
+  std::string error_msg;
+  auto output = mlir::openOutputFile(output_file_name, &error_msg);
+  if (output == nullptr) {
+    llvm::errs() << error_msg << '\n';
+    return kTrFailure;
+  }
+  output->os() << result;
+  output->keep();
+
+  // Print out debugging info related to function mapping.
+  if (print_function_result_mapping) return PrintFunctionResultMapping(result);
+  return kTrSuccess;
+}
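
A hedged sketch of the same pipeline driven programmatically, using only the entry points declared in tf_to_tfjs_json.h; argument values ("model.pbtxt", the array names and shapes) are illustrative:

mlir::MLIRContext context;
llvm::SourceMgr source_mgr;
// Import the GraphDef into an MLIR module.
auto module = tensorflow::LoadFromGraphdefOrMlirSource(
    "model.pbtxt", /*input_mlir=*/false, /*extra_tf_opdefs=*/{},
    /*debug_info_file=*/"", /*input_arrays=*/"input",
    /*input_dtypes=*/"DT_FLOAT", /*input_shapes=*/"1,224,224,3",
    /*output_arrays=*/"output", /*prune_unused_nodes=*/true, &source_mgr,
    &context);
if (!module.ok()) return kTrFailure;  // as in main() above

// Lower TF ops to the TFJS dialect and serialize to JSON.
mlir::PassManager pm(&context);
tensorflow::AddTFToTFJSConversionPasses(&pm);
std::string json;
auto status = tensorflow::ConvertTFOpsToTfjsJSON(
    module.ValueOrDie().get(), /*export_to_mlir=*/false, &json, &pm);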
diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc
new file mode 100644
index 0000000..7dc9ea0
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc
@@ -0,0 +1,152 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Parser.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/FileUtilities.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace tensorflow {
+
+using mlir::MLIRContext;
+using mlir::ModuleOp;
+using mlir::OwningModuleRef;
+using stream_executor::port::StatusOr;
+
+namespace {
+tensorflow::Status RegisterCustomOps(
+    const std::vector<std::string>& extra_tf_opdefs) {
+  for (const auto& tf_opdefs_string : extra_tf_opdefs) {
+    tensorflow::OpDef opdef;
+    if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
+                                                           &opdef)) {
+      LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
+      return errors::InvalidArgument("fail to parse extra OpDef");
+    }
+    // Register extra opdefs.
+    tensorflow::OpRegistry::Global()->Register(
+        [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
+          *op_reg_data = tensorflow::OpRegistrationData(opdef);
+          return Status::OK();
+        });
+  }
+  return Status::OK();
+}
+}  // namespace
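
To illustrate what RegisterCustomOps consumes, here is the kind of textual OpDef one could pass via --tf-custom-opdefs; "MyCustomOp" is a made-up name for the example:

// A hypothetical single-input, single-output op in OpDef text format.
constexpr char kExtraOpDef[] = R"(
  name: "MyCustomOp"
  input_arg { name: "x" type: DT_FLOAT }
  output_arg { name: "y" type: DT_FLOAT }
)";
// RegisterCustomOps({kExtraOpDef}) would parse it and add it to the global
// op registry before import.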
+
+StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
+    const std::string& input_filename, bool input_mlir,
+    const std::vector<std::string>& extra_tf_opdefs,
+    absl::string_view debug_info_file, absl::string_view input_arrays,
+    absl::string_view input_dtypes, absl::string_view input_shapes,
+    absl::string_view output_arrays, bool prune_unused_nodes,
+    llvm::SourceMgr* source_mgr, MLIRContext* context) {
+  // Set up the input file.
+  std::string error_message;
+  auto file = mlir::openInputFile(input_filename, &error_message);
+  if (!file) {
+    llvm::errs() << error_message << "\n";
+    return errors::InvalidArgument("fail to open input file");
+  }
+
+  if (input_mlir) {
+    source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+    return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context));
+  }
+
+  TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
+
+  return tensorflow::GraphdefToMlirTranslateFunction(
+      file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
+      input_shapes, output_arrays, /*control_output_arrays=*/"",
+      prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
+      /*graph_as_function=*/false, /*upgrade_legacy=*/true,
+      /*enable_shape_inference=*/true, context);
+}
+
+Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir,
+                              std::string* result,
+                              mlir::PassManager* pass_manager) {
+  mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
+                                                    /*propagate=*/true);
+  if (failed(pass_manager->run(module))) {
+    return statusHandler.ConsumeStatus();
+  }
+
+  if (export_to_mlir) {
+    llvm::raw_string_ostream os(*result);
+    module.print(os);
+    return Status::OK();
+  }
+
+  return tfjs::MlirToJSONTranslateFunction(module, result)
+             ? Status::OK()
+             : statusHandler.ConsumeStatus();
+}
+
+StatusOr<mlir::OwningModuleRef> ImportSavedModel(
+    bool import_saved_model, bool import_saved_model_v1,
+    const std::vector<std::string>& extra_tf_opdefs,
+    const std::string& input_filename, const std::string& saved_model_tags,
+    const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
+  std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
+  std::vector<std::string> exported_names_in_vector =
+      absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
+  absl::Span<std::string> exported_names(exported_names_in_vector);
+  if (import_saved_model) {
+    auto module = tensorflow::SavedModelObjectGraphToMlirImport(
+        input_filename, tags, exported_names, context);
+    if (!module)
+      return tensorflow::errors::InvalidArgument("fail to open input file");
+    TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
+    return module;
+  } else if (import_saved_model_v1) {
+    auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
+        input_filename, tags, exported_names, context);
+
+    if (!module)
+      return tensorflow::errors::InvalidArgument("fail to open input file");
+    TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
+    return module;
+  } else {
+    return tensorflow::errors::InvalidArgument(
+        "Should be either saved model v1 or v2");
+  }
+}
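
A minimal usage sketch for the v2 (object graph) path; the SavedModel path is illustrative:

mlir::MLIRContext context;
auto module = tensorflow::ImportSavedModel(
    /*import_saved_model=*/true, /*import_saved_model_v1=*/false,
    /*extra_tf_opdefs=*/{}, "/tmp/saved_model", /*saved_model_tags=*/"serve",
    /*saved_model_exported_names=*/"", &context);
if (!module.ok()) {
  LOG(ERROR) << module.status();
}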
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h
new file mode 100644
index 0000000..d68f0e7
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
+#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace tensorflow {
+
+// Loads a TF model from a GraphDef definition or a TF control flow dialect
+// MLIR source into an MLIR module. If `input_mlir` is true, loads from an MLIR
+// source file; otherwise, loads from a GraphDef.
+// Setting `prune_unused_nodes` to true prunes unreachable nodes if
+// `output_arrays` is specified.
+stream_executor::port::StatusOr<mlir::OwningModuleRef>
+LoadFromGraphdefOrMlirSource(
+    const std::string& input_filename, bool input_mlir,
+    const std::vector<std::string>& extra_tf_opdefs,
+    absl::string_view debug_info_file, absl::string_view input_arrays,
+    absl::string_view input_dtypes, absl::string_view input_shapes,
+    absl::string_view output_arrays, bool prune_unused_nodes,
+    llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
+
+// Loads a SavedModel (either v1 or v2) into MLIR.
+stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
+    bool import_saved_model, bool import_saved_model_v1,
+    const std::vector<std::string>& extra_tf_opdefs,
+    const std::string& input_filename, const std::string& saved_model_tags,
+    const std::string& saved_model_exported_names, mlir::MLIRContext* context);
+
+// Takes an MLIR module in the TF executor dialect, applies a set of passes
+// to convert it to the TFJS dialect, and serializes the result to a JSON
+// string.
+// If `export_to_mlir` is true, the result is exported in MLIR text format;
+// otherwise it is exported as JSON.
+Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir,
+                              std::string* result,
+                              mlir::PassManager* pass_manager);
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
index 8a187cf..9257114 100644
--- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
@@ -29,11 +29,25 @@
 #include "tfrt/tensor/dense_host_tensor_view.h"
 
 namespace tensorflow {
+namespace {
 
-void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
+llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) {
+  if (index_path.size() == 1 && index_path[0].isa<mlir::StringAttr>()) {
+    // TODO(chky): Support cases where index_path is not a single string.
+    return index_path[0].cast<mlir::StringAttr>().getValue();
+  }
+  return "";
+}
+
+}  // namespace
+
+void MapFunctionSignaturesFromTFSavedModelMLIR(
     mlir::ModuleOp module,
     llvm::function_ref<void(
         llvm::StringRef func_name,
+        llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
+            input_names_and_devices,
+        llvm::ArrayRef<llvm::StringRef> output_names,
         llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
         map_fn) {
   // Create global_tensors for each functions.
@@ -44,17 +58,38 @@
     auto func_names = mlir::tf_saved_model::GetExportedNames(func);
     if (func_names.empty()) return;
 
-    // Here we walk through each arguments and find out the variables used by
-    // this function.
+    // Walk through each argument and collect the input/output names, input
+    // devices, and variables used by this function.
+    llvm::SmallVector<std::pair<llvm::StringRef, llvm::StringRef>, 4>
+        input_names_and_devices;
     llvm::SmallVector<mlir::tf_saved_model::GlobalTensorOp, 4> global_tensors;
     for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) {
+      if (auto input_index_path = func.getArgAttrOfType<mlir::ArrayAttr>(
+              i, "tf_saved_model.index_path")) {
+        std::pair<llvm::StringRef, llvm::StringRef> name_and_device;
+        name_and_device.first = ProcessIndexPath(input_index_path);
+        if (auto input_device =
+                func.getArgAttrOfType<mlir::StringAttr>(i, "tf.device")) {
+          name_and_device.second = input_device.getValue();
+        }
+        input_names_and_devices.push_back(name_and_device);
+      }
       if (auto variable =
               mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table)) {
         global_tensors.push_back(variable);
       }
     }
 
-    for (auto func_name : func_names) map_fn(func_name, global_tensors);
+    llvm::SmallVector<llvm::StringRef, 4> output_names;
+    for (unsigned i = 0, e = func.getNumResults(); i != e; ++i) {
+      if (auto output_index_path = func.getResultAttrOfType<mlir::ArrayAttr>(
+              i, "tf_saved_model.index_path")) {
+        output_names.push_back(ProcessIndexPath(output_index_path));
+      }
+    }
+
+    for (auto func_name : func_names)
+      map_fn(func_name, input_names_and_devices, output_names, global_tensors);
   });
 }
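
A sketch of consuming the widened callback; parameter names mirror the header, `module` is assumed to be an mlir::ModuleOp in the tf_saved_model dialect, and the printing is purely illustrative:

tensorflow::MapFunctionSignaturesFromTFSavedModelMLIR(
    module, [](llvm::StringRef func_name,
               llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
                   input_names_and_devices,
               llvm::ArrayRef<llvm::StringRef> output_names,
               llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp>
                   global_tensors) {
      // Dump each exported function's signature.
      for (const auto& input : input_names_and_devices)
        llvm::outs() << func_name << ": input " << input.first
                     << " on device " << input.second << "\n";
      for (llvm::StringRef output : output_names)
        llvm::outs() << func_name << ": output " << output << "\n";
    });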
 
diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
index de24ea2..06a6c5a 100644
--- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
+++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
@@ -57,12 +57,15 @@
   std::string force_data_format;
 };
 
-// Map captured global tensors for each function.
-void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
+// Maps signatures (e.g. input/output names, variables) for each function.
+void MapFunctionSignaturesFromTFSavedModelMLIR(
     mlir::ModuleOp module,
-    llvm::function_ref<
-        void(llvm::StringRef func_name,
-             llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> captures)>
+    llvm::function_ref<void(
+        llvm::StringRef func_name,
+        llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
+            input_names_and_devices,
+        llvm::ArrayRef<llvm::StringRef> output_names,
+        llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
         map_fn);
 
 // Compile MLIR in TF saved model dialect into BEF.
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/build_defs.bzl b/tensorflow/compiler/mlir/tools/kernel_gen/build_defs.bzl
deleted file mode 100644
index cec9968..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/build_defs.bzl
+++ /dev/null
@@ -1,96 +0,0 @@
-load("//third_party/gpus/cuda:build_defs.bzl", "cuda_gpu_select_list")
-
-def _lookup_file(filegroup, path):
-    """Extracts file at (relative) path in filegroup."""
-    for file in filegroup.files.to_list():
-        if file.path.endswith(path):
-            return file
-    return None
-
-def _gen_kernel_image_hdr_impl(ctx):
-    if not ctx.attr.gpu_archs:
-        fail("No GPU architecture specified, use --config=cuda or similar.")
-
-    name = ctx.attr.name
-    tile_sizes = ctx.attr.tile_size.replace("x", ",")
-    same_shape = []
-    if ctx.attr.same_shape:
-        same_shape.append("--same_shape=%s" % ctx.attr.same_shape)
-
-    cubins = []
-    images = []
-    for arch in ctx.attr.gpu_archs:
-        filename = "%s.%s.cubin" % (name, arch)
-        cubin = ctx.actions.declare_file(filename)
-        ctx.actions.run(
-            outputs = [cubin],
-            executable = ctx.executable._tool,
-            arguments = same_shape + [
-                "--tile_sizes=%s" % tile_sizes,
-                "--arch=%s" % arch.split("_")[1],
-                "--output=%s" % cubin.path,
-                ctx.attr.op,
-            ],
-            mnemonic = "compile",
-        )
-        cubins.append(cubin)
-        images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
-
-    # Generate fatbin file from all cubins.
-    fatbin = ctx.actions.declare_file("%s.fatbin" % name)
-    ctx.actions.run(
-        outputs = [fatbin],
-        inputs = cubins,
-        executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"),
-        arguments = [
-            "--64",
-            "--cmdline=--compile-only",
-            "--link",
-            "--compress-all",
-            "--create=%s" % fatbin.path,
-        ] + images,
-        mnemonic = "fatbinary",
-    )
-
-    bin2c = _lookup_file(ctx.attr._cuda_root, "bin/bin2c")
-    ctx.actions.run_shell(
-        outputs = [ctx.outputs.out],
-        inputs = [fatbin],
-        tools = [bin2c],
-        command = "%s --static --const --type=int --name=%s %s 1> %s" %
-                  (bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path),
-        mnemonic = "bin2c",
-    )
-
-_gen_kernel_image_hdr = rule(
-    implementation = _gen_kernel_image_hdr_impl,
-    output_to_genfiles = True,
-    attrs = {
-        "op": attr.string(mandatory = True),
-        "tile_size": attr.string(mandatory = True),
-        "same_shape": attr.string(),
-        "out": attr.output(mandatory = True),
-        "symbol": attr.string(mandatory = True),
-        "gpu_archs": attr.string_list(mandatory = True),
-        "_cuda_root": attr.label(
-            default = Label("//third_party/gpus/cuda:cuda_root"),
-        ),
-        "_tool": attr.label(
-            executable = True,
-            default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
-            cfg = "host",
-        ),
-    },
-)
-
-def gen_kernel_image_hdr(name, op, tile_size, same_shape = None):
-    """Generates a C header with fatbin data from a Tensorflow op."""
-    _gen_kernel_image_hdr(
-        name = name,
-        op = op,
-        tile_size = tile_size,
-        same_shape = same_shape,
-        out = "include/tfrt/gpu/ops/tf/%s.h" % name,
-        symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
-        gpu_archs = cuda_gpu_select_list("sm_{}"),
-    )
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
index 46af4e4..b1c4b1b 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
@@ -64,9 +64,9 @@
 
 StatusOr<std::string> GetLibdeviceDir(
     const xla::HloModuleConfig& hlo_module_config) {
-  for (const string& cuda_root : tensorflow::CandidateCudaRoots(
+  for (const std::string& cuda_root : tensorflow::CandidateCudaRoots(
            hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) {
-    string libdevice_dir =
+    std::string libdevice_dir =
         tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
     VLOG(2) << "Looking for libdevice at " << libdevice_dir;
     if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
@@ -136,7 +136,7 @@
     : public mlir::PassWrapper<PropagateStaticKnowledge,
                                mlir::OperationPass<mlir::LLVM::LLVMFuncOp>> {
   explicit PropagateStaticKnowledge(mlir::FunctionType type,
-                                    llvm::ArrayRef<unsigned> same_shape_)
+                                    llvm::ArrayRef<uint32_t> same_shape_)
       : func_type(type), same_shape(same_shape_) {}
 
   void runOnOperation() override {
@@ -152,8 +152,8 @@
         func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1));
     mlir::Value zero = b.create<mlir::LLVM::ConstantOp>(
         func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0));
-    unsigned arg_pos = 0;
-    std::vector<unsigned> positions;
+    uint32_t arg_pos = 0;
+    std::vector<uint32_t> positions;
     for (mlir::Type arg_type : func_type.getInputs()) {
       positions.push_back(arg_pos);
       func.getArgument(arg_pos + 2).replaceAllUsesWith(zero);
@@ -165,13 +165,13 @@
     // can use that here. Simply replace usages of the shape parameters within
     // the function body to a single shape parameter.
     if (!same_shape.empty()) {
-      int first = same_shape.front();
-      int first_offset = positions.at(first);
+      auto first = same_shape.front();
+      auto first_offset = positions.at(first);
       mlir::ShapedType first_type =
           func_type.getInput(first).cast<mlir::ShapedType>();
-      unsigned rank = first_type.getRank();
-      for (int same : same_shape.drop_front(1)) {
-        unsigned same_offset = positions.at(same);
+      uint32_t rank = first_type.getRank();
+      for (auto same : same_shape.drop_front(1)) {
+        uint32_t same_offset = positions.at(same);
         auto same_type = func_type.getInput(same).cast<mlir::ShapedType>();
         if (same_type.getRank() != rank) {
           func.emitOpError() << "same shape constraints on arguments with "
@@ -180,7 +180,7 @@
           signalPassFailure();
         }
 
-        for (int i = 0; i < 2 * rank; ++i) {
+        for (uint32_t i = 0; i < 2 * rank; ++i) {
           // Replace uses for second arg data with first arg.
           auto same_arg = func.getArgument(same_offset + 3 + i);
           auto first_arg = func.getArgument(first_offset + 3 + i);
@@ -191,11 +191,11 @@
   }
 
   mlir::FunctionType func_type;
-  llvm::ArrayRef<unsigned> same_shape;
+  llvm::ArrayRef<uint32_t> same_shape;
 };
 
 Status PropagateStaticShapeKnowledgeToKernel(
-    mlir::ModuleOp module, llvm::ArrayRef<unsigned> same_shape) {
+    mlir::ModuleOp module, llvm::ArrayRef<uint32_t> same_shape) {
   // Grab the original signature from the single function.
   auto func = *module.getBody()->op_begin<mlir::FuncOp>();
 
@@ -218,10 +218,10 @@
 }
 }  // namespace
 
-StatusOr<std::vector<uint8>> tensorflow::kernel_gen::GenerateCubinForTfCode(
-    llvm::StringRef tf_code, std::pair<int, int> compute_capability,
-    llvm::ArrayRef<unsigned> tile_sizes, llvm::ArrayRef<unsigned> same_shape,
-    llvm::ArrayRef<unsigned> unroll_factors) {
+StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
+    llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
+    llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
+    llvm::ArrayRef<uint32_t> unroll_factors) {
   mlir::MLIRContext context;
   context.allowUnregisteredDialects();  // TODO(b/152572127)
   mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h
index c874633..47626ba 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h
@@ -30,11 +30,12 @@
 
 namespace tensorflow {
 namespace kernel_gen {
-xla::StatusOr<std::vector<uint8>> GenerateCubinForTfCode(
-    llvm::StringRef tf_code, std::pair<int, int> compute_capability = {7, 5},
-    llvm::ArrayRef<unsigned> tile_sizes = {16, 64},
-    llvm::ArrayRef<unsigned> same_shape = {},
-    llvm::ArrayRef<unsigned> unroll_factors = {});
+xla::StatusOr<std::vector<uint8_t>> GenerateCubinForTfCode(
+    llvm::StringRef tf_code,
+    std::pair<int32_t, int32_t> compute_capability = {7, 5},
+    llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
+    llvm::ArrayRef<uint32_t> same_shape = {},
+    llvm::ArrayRef<uint32_t> unroll_factors = {});
 }  // namespace kernel_gen
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc
index d39edd8..8edc567 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc
@@ -31,9 +31,9 @@
 #include "tensorflow/core/util/command_line_flags.h"
 
 namespace {
-bool ParseStringList(std::string string_list, std::vector<uint32>* result) {
+bool ParseStringList(std::string string_list, std::vector<uint32_t>* result) {
   result->clear();
-  uint32 item;
+  uint32_t item;
   auto items = absl::StrSplit(string_list, ',');
   for (const auto& item_str : items) {
     if (!absl::SimpleAtoi(item_str, &item)) {
@@ -48,10 +48,10 @@
 
 int main(int argc, char** argv) {
   std::string output_file = "foo.bin";
-  int32 architecture = 50;
-  std::vector<uint32> tile_sizes;
-  std::vector<uint32> unroll_factors;
-  std::vector<uint32> same_shape;
+  int32_t architecture = 50;
+  std::vector<uint32_t> tile_sizes;
+  std::vector<uint32_t> unroll_factors;
+  std::vector<uint32_t> same_shape;
 
   auto parse_tile_sizes = [&tile_sizes](std::string tile_sizes_str) {
     if (!ParseStringList(tile_sizes_str, &tile_sizes)) {
@@ -91,8 +91,8 @@
     return 1;
   }
 
-  std::pair<int32, int32> compute_capability(architecture / 10,
-                                             architecture % 10);
+  std::pair<int32_t, int32_t> compute_capability(architecture / 10,
+                                                 architecture % 10);
 
   auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
       argv[1], compute_capability, tile_sizes, same_shape, unroll_factors);
@@ -102,7 +102,7 @@
     return 1;
   }
 
-  std::vector<uint8> cubin_data = cubin.ConsumeValueOrDie();
+  std::vector<uint8_t> cubin_data = cubin.ConsumeValueOrDie();
 
   auto status = tensorflow::WriteStringToFile(
       tensorflow::Env::Default(), output_file,
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index e0e93c3..d9108e8 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -240,8 +240,8 @@
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
-        "@llvm-project//mlir:LoopOps",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Transforms",
     ],
@@ -278,8 +278,8 @@
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
-        "@llvm-project//mlir:LoopOps",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Transforms",
     ],
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 874f4a1..68eafb8 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -1170,9 +1170,22 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
+  auto input = operand();
+
   // No dimensions to reverse.
-  if (dimensions().getNumElements() == 0) return operand();
-  return nullptr;
+  if (dimensions().getNumElements() == 0) return input;
+
+  // Reversing a dimension of size one is a no-op, so the fold succeeds only
+  // when every reversed dimension has size one.
+
+  auto shaped_type = input.getType().cast<ShapedType>();
+  for (auto dim : dimensions().getValues<APInt>()) {
+    if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) {
+      return nullptr;
+    }
+  }
+
+  return input;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 917a50f..0db9563 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -229,7 +229,7 @@
     BASE_HLO_SqrtOp;
 
 def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
-    [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType],
+    [NoSideEffect, SameOperandsAndResultType],
     HLO_FpOrComplexTensor>, BASE_HLO_TanhOp;
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 99d1da7..cc334d8 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -56,6 +56,20 @@
   return mlir::DenseIntElementsAttr::get(ty, mlir_values);
 }
 
+static mlir::DenseIntElementsAttr ConvertPadding(
+    absl::Span<const std::pair<tensorflow::int64, tensorflow::int64>> padding,
+    mlir::Builder* builder) {
+  llvm::SmallVector<int64_t, 8> elements;
+  elements.reserve(padding.size() * 2);
+  for (const auto& vals : padding) {
+    elements.push_back(vals.first);
+    elements.push_back(vals.second);
+  }
+  auto ty = mlir::RankedTensorType::get(
+      {static_cast<int64_t>(padding.size()), 2}, builder->getIntegerType(64));
+  return mlir::DenseIntElementsAttr::get(ty, elements);
+}
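
A worked example of the padding layout, with illustrative values and an assumed initialized mlir::Builder:

// Two spatial dimensions with (low, high) padding (1, 2) and (0, 3) become a
// tensor<2x2xi64> attribute [[1, 2], [0, 3]] -- one row per dimension.
std::pair<tensorflow::int64, tensorflow::int64> pads[] = {{1, 2}, {0, 3}};
mlir::DenseIntElementsAttr attr =
    ConvertPadding(absl::MakeConstSpan(pads), &builder);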
+
 MlirHloBuilder::~MlirHloBuilder() = default;
 
 StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
@@ -79,6 +93,31 @@
   });
 }
 
+StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
+    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
+    absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count, int64 batch_group_count,
+    const PrecisionConfig* precision_config) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  mlir::ArrayAttr config_attr;
+  if (precision_config)
+    config_attr = ConvertPrecisionConfig(precision_config, &builder_);
+  auto op = builder_.create<mlir::xla_hlo::ConvOp>(
+      loc_, ty, GetValue(lhs), GetValue(rhs),
+      GetI64ElementsAttr(window_strides, &builder_),
+      ConvertPadding(padding, &builder_),
+      GetI64ElementsAttr(lhs_dilation, &builder_),
+      GetI64ElementsAttr(rhs_dilation, &builder_),
+      ConvertConvDimensionNumbers(dimension_numbers, &builder_),
+      builder_.getI64IntegerAttr(feature_group_count),
+      builder_.getI64IntegerAttr(batch_group_count), config_attr);
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
     const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index dbcb685..5a84d60 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -110,6 +110,16 @@
  private:
   XlaOp ConstantLiteral(const LiteralSlice& literal) override;
 
+  StatusOr<XlaOp> ConvGeneralDilatedInternal(
+      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
+      absl::Span<const int64> window_strides,
+      absl::Span<const std::pair<int64, int64>> padding,
+      absl::Span<const int64> lhs_dilation,
+      absl::Span<const int64> rhs_dilation,
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      int64 feature_group_count, int64 batch_group_count,
+      const PrecisionConfig* precision_config) override;
+
   StatusOr<XlaOp> TransposeInternal(
       const Shape& shape, XlaOp operand,
       absl::Span<const int64> permutation) override;
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 21668b7..228a26b 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -986,6 +986,21 @@
     return LowerFunctionCall(&call_op, builder, &value_map);
   }
 
+  if (auto op = dyn_cast<mlir::TensorCastOp>(inst)) {
+    Value operand = op.getOperand();
+    auto ty = operand.getType().dyn_cast<ShapedType>();
+    // A cast from a statically shaped tensor is a no-op for export to HLO,
+    // so we can use the operand directly.
+    if (!ty || !ty.hasStaticShape()) {
+      inst->emitOpError()
+          << "requires static shaped operand for HLO translation";
+      return failure();
+    }
+
+    value_map[op.getResult()] = value_map[operand];
+    return success();
+  }
+
   // TODO(jpienaar): This doesn't support layouts yet.
   if (matchPattern(inst, m_Constant(&const_attr))) {
     auto literal_or = CreateLiteralFromAttr(const_attr);
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
index 262533b..53296b2 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -1,4 +1,4 @@
-// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -13,33 +13,42 @@
 
 // -----
 
+func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  return %arg0 : tensor<4xf32>
+}
+//      CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
+// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
+// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
+
+// -----
+
 // CHECK-LABEL: func @func_op_long
 func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-  // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
-  // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
-  // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
-  // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
-  // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
-  // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
   %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
   %2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
   %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
   %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
   %5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
-  // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
-  // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
-  // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
-  // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
-  // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
-  // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
   return %5 : tensor<4xf32>
-  // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 }
+//      CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
+// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
+// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
+// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
+// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
+// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
+// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
+// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
+// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
+// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
+// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 
 // -----
 
@@ -47,20 +56,20 @@
 func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
              %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
   // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
-  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
-  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
+  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
   %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
   %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
   %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
       : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
   // CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
+  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
   %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
   %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
       : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
   // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
+  // CHECK-NEXT:  dealloc %[[ADD_RESULT]] : memref<2x2xf32>
   // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
   tensor_store %tensor_result, %result : memref<2x2xf32>
-  // CHECK-NEXT:  dealloc %[[ADD_RESULT]] : memref<2x2xf32>
   // CHECK-NEXT:  dealloc %[[MUL_RESULT]] : memref<2x2xf32>
   // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
   "xla_lhlo.terminator"() : () -> ()
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
index aa949a0..a856ee5 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
@@ -530,3 +530,15 @@
 // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64):
 // CHECK-NEXT:   %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32
 // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: func @convert_f32_to_i32
+func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
+  %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
+  return %result : tensor<2x2xi32>
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
+// CHECK-NEXT:   %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
index cda1dc4..6a2b68a 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
@@ -8,7 +8,9 @@
 // CHECK-SAME: ) {
 func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
   // The only expected instruction is a copy from the input into the output.
-  // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][][] : memref<16xi8> to memref<2x2xf32>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C02:.*]] = constant 0 : index
+  // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C02]]][] : memref<16xi8> to memref<2x2xf32>
   // CHECK: xla_lhlo.copy
   // CHECK-SAME: %[[ARG0]], %[[OUTPUT]]
   return %value : tensor<2x2xf32>
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index aef9b17..a5353be 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1596,6 +1596,44 @@
   return %0, %1 : tensor<i32>, tensor<i32>
 }
 
+
+//===----------------------------------------------------------------------===//
+// ReverseV2 op legalization.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @reverse_func_32
+func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> {
+  %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>)
+
+  // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
+  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32>
+
+  // CHECK: return [[VAL]] : tensor<5xi32>
+  return %reversed : tensor<5xi32>
+}
+
+// CHECK-LABEL: @reverse_func_64
+func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> {
+  %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>)
+
+  // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
+  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32>
+
+  // CHECK: return [[VAL]] : tensor<5xi32>
+  return %reversed : tensor<5xi32>
+}
+
+// CHECK-LABEL: @reverse_func_neg
+func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> {
+  %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+
+  // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>}
+  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32>
+
+  // CHECK: return [[VAL]] : tensor<5x5xi32>
+  return %reversed : tensor<5x5xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // StatefulPartitionedCall op legalization.
 //===----------------------------------------------------------------------===//
@@ -4074,6 +4112,41 @@
   return %0 : tensor<4x16xf32>
 }
 
+// CHECK-LABEL: inplace_update_one
+func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> {
+  // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0>
+  // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]])
+  // CHECK-DAG: [[UPDATE:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]])
+  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32>
+
+  // CHECK: return [[UPDATE]]
+  return %0 : tensor<8x4xf32>
+}
+
+// CHECK-LABEL: inplace_update_three
+func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> {
+  // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0>
+  // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[SLICE3:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[SLICE4:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+  // CHECK-DAG: [[SLICE5:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+  // CHECK-DAG: [[SLICE6:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+  // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]])
+  // CHECK-DAG: [[RESHAPE2:%.+]] = "xla_hlo.reshape"([[SLICE2]])
+  // CHECK-DAG: [[RESHAPE3:%.+]] = "xla_hlo.reshape"([[SLICE3]])
+  // CHECK-DAG: [[UPDATE1:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]])
+  // CHECK-DAG: [[UPDATE2:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]])
+  // CHECK-DAG: [[UPDATE3:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]])
+  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32>
+
+  // CHECK:  return [[UPDATE3]] : tensor<8x8x4xf32>
+  return %0 : tensor<8x8x4xf32>
+}
+
 // CHECK-LABEL: xla_dynamic_update_slice
 func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> {
   // CHECK: [[SLICE0:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
@@ -4097,6 +4170,21 @@
 }
 
 //===----------------------------------------------------------------------===//
+// AllToAll op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @alltoall_basic
+func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> {
+  %group_assignment = "tf.Const" () {
+    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32>
+  } : () -> tensor<3x4xi32>
+  %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} :  (tensor<10xf32>, tensor<3x4xi32>)  -> tensor<10xf32>
+  // CHECK: xla_hlo.all_to_all
+  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64>
+  return %result : tensor<10xf32>
+}
+
+//===----------------------------------------------------------------------===//
 // Cumsum op legalizations.
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
index 3605e8e..bb8010b 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
@@ -411,6 +411,19 @@
 
 // -----
 
+// CHECK-LABEL: func @convert_f32_to_i32
+func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
+  "xla_lhlo.convert"(%input, %result)
+      : (memref<2x2xf32>, memref<2x2xi32>) -> ()
+  return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i32):
+// CHECK-NEXT:   %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
+// CHECK-NEXT:   linalg.yield %[[RESULT]] : i32
+
+// -----
+
 // CHECK-LABEL: func @cos
 func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
   "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
index 4050340..2340650 100644
--- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
@@ -20,6 +20,17 @@
 
 // -----
 
+// CHECK-LABEL: @addBroadcastEqual
+func @addBroadcastEqual(%arg0: tensor<4x1xf32>, %arg1: tensor<1x4xf32>) -> tensor<4x4xf32> {
+  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4xf32>
+  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<4x4xf32>
+  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<4x4xf32>
+  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @addBroadcastMultidimension
 func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
   // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32>
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index ed06863..15fa915 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -1044,3 +1044,16 @@
 // CHECK: ENTRY
 // CHECK: %[[ARG0:.*]] = u8[4] parameter(0)
 //  ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
+
+// -----
+
+// CHECK:  HloModule
+func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
+  %0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32>
+  return %1 : tensor<*xi32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[ARG0:.*]] = s32[4] parameter(0)
+//  ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]])
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index aa29241..10f3576 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -27,6 +27,7 @@
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@@ -39,16 +40,11 @@
 namespace {
 
 constexpr StringRef kTempBufferAttr = "temp";
-
-/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc.
-Operation* FindInsertionPointForCopy(Value value) {
-  for (const auto& user : value.getUsers()) {
-    if (auto dealloc = dyn_cast<DeallocOp>(user)) {
-      return user;
-    }
-  }
-  return nullptr;
-}
+template <typename T>
+using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
+using StdReturnOpConverter =
+    NonVoidToVoidReturnOpConverter<mlir::ReturnOp, xla_lhlo::TerminatorOp,
+                                   xla_lhlo::CopyOp>;
 
 Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                    Value shape_operand,
@@ -92,8 +88,9 @@
   return alloc;
 }
 
-Value InsertAllocAndDealloc(Location loc, Value result,
-                            ConversionPatternRewriter* rewriter) {
+Value InsertAlloc(Location loc, OpResult result,
+                  BufferAssignmentPlacer* bufferAssignment,
+                  ConversionPatternRewriter* rewriter) {
   auto result_type = result.getType().dyn_cast<ShapedType>();
   if (!result_type || !result_type.hasStaticShape()) {
     result.getDefiningOp()->emitOpError()
@@ -101,31 +98,21 @@
   }
   auto memref_type =
       MemRefType::get(result_type.getShape(), result_type.getElementType());
-
-  Operation* op = result.getDefiningOp();
-  auto block = op->getBlock();
-
-  OpBuilder allocBuilder(op);
-  allocBuilder.setInsertionPointToStart(block);  // Inserting at the beginning
-  auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
-
-  alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
-
-  allocBuilder.setInsertionPoint(block, std::prev(block->end()));
-  allocBuilder.create<DeallocOp>(loc, alloc);
-
+  OpBuilder::InsertionGuard guard(*rewriter);
+  rewriter->restoreInsertionPoint(
+      bufferAssignment->computeAllocPosition(result));
+  auto alloc = rewriter->create<AllocOp>(loc, memref_type);
   return alloc;
 }
 
 template <typename HloOpTy>
-class HloToLhloOpConverter : public ConversionPattern {
+class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
  public:
-  explicit HloToLhloOpConverter(MLIRContext* context)
-      : ConversionPattern(HloOpTy::getOperationName(), 1, context) {}
-
+  using BaseOpConversion<HloOpTy>::BaseOpConversion;
   LogicalResult matchAndRewrite(
-      Operation* op, ArrayRef<Value> operands,
+      HloOpTy hloOp, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
+    Operation* op = hloOp.getOperation();
     const auto& original_results = op->getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
     for (auto result : llvm::enumerate(original_results)) {
@@ -135,8 +122,8 @@
         return failure();
       }
       if (resultType.hasStaticShape()) {
-        buffer_args.push_back(
-            InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter));
+        buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(),
+                                          this->bufferAssignment, &rewriter));
       } else {
         SmallVector<Value, 1> results_shape;
         auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
@@ -156,9 +143,9 @@
 };
 
 struct HloToLhloDynamicBroadcastInDimOpConverter
-    : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
+    : public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
  public:
-  using OpConversionPattern::OpConversionPattern;
+  using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
 
   LogicalResult matchAndRewrite(
       xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
@@ -175,10 +162,9 @@
   }
 };
 
-struct HloToLhloReduceOpConverter
-    : public OpConversionPattern<xla_hlo::ReduceOp> {
+struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
  public:
-  using OpConversionPattern::OpConversionPattern;
+  using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
 
   LogicalResult matchAndRewrite(
       xla_hlo::ReduceOp op, ArrayRef<Value> operands,
@@ -194,7 +180,8 @@
     const auto& original_results = op.getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
     for (auto result : original_results) {
-      buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter));
+      buffer_args.push_back(
+          InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
     }
     auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
         loc, llvm::None, buffer_args, op.getAttrs());
@@ -230,12 +217,12 @@
   }
 };
 
-class HloToLhloTensorLoadOpConverter : public ConversionPattern {
+class HloToLhloTensorLoadOpConverter
+    : public BaseOpConversion<mlir::TensorLoadOp> {
  public:
-  explicit HloToLhloTensorLoadOpConverter(MLIRContext* context)
-      : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {}
+  using BaseOpConversion<mlir::TensorLoadOp>::BaseOpConversion;
   LogicalResult matchAndRewrite(
-      Operation* op, ArrayRef<Value> operands,
+      mlir::TensorLoadOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     rewriter.replaceOp(op, operands);
     return success();
@@ -243,13 +230,13 @@
 };
 
 // TODO(b/137624192): Rewrite into a copy and elide copy if possible.
-class HloToLhloTensorStoreOpConverter : public ConversionPattern {
+class HloToLhloTensorStoreOpConverter
+    : public BaseOpConversion<mlir::TensorStoreOp> {
  public:
-  explicit HloToLhloTensorStoreOpConverter(MLIRContext* context)
-      : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {}
+  using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
 
   LogicalResult matchAndRewrite(
-      Operation* op, ArrayRef<Value> operands,
+      mlir::TensorStoreOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
         op, llvm::None, operands.front(), operands.back());
@@ -291,7 +278,6 @@
 //         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
 //     "xla_lhlo.multiply"(%0, %arg0, %arg3) :
 //         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
-//     dealloc %0 : memref<2x2xf32>
 //     "xla_lhlo.terminator"() : () -> ()
 //   }) : () -> ()
 //   return
@@ -313,14 +299,13 @@
 //               %arg1: memref<4xf32>,
 //               %arg2: memref<4xf32>) {
 //   %0 = alloc() : memref<4xf32>
-//   %1 = alloc() : memref<4xf32>
+
 //   "xla_lhlo.maximum"(%arg0, %arg1, %0) :
 //         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
+//   %1 = alloc() : memref<4xf32>
 //   "xla_lhlo.add"(%arg0, %0, %1) :
 //         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
 //   "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
-//   dealloc %0 : memref<4xf32>
-//   dealloc %1 : memref<4xf32>
 //   "xla_lhlo.terminator"() : () -> ()
 // }
 
@@ -346,101 +331,25 @@
     });
 
     auto module = getOperation();
-    populateHLOToLHLOConversionPattern(module.getContext(), &patterns);
-
-    // Do partial conversion so we can have unknown ops in tests.
-    if (failed(applyPartialConversion(module, target, patterns, nullptr))) {
-      signalPassFailure();
-    }
+    BufferAssignmentTypeConverter converter;
+    module.walk([&](FuncOp func) {
+      BufferAssignmentPlacer bufferAssignment(func);
+      OwningRewritePatternList patterns;
+      populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
+                                         &converter, &patterns);
+      return WalkResult(
+          applyPartialConversion(func, target, patterns, &converter));
+    });
   }
 };
-
-Type ConvertType(Type t) {
-  if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
-    return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-  }
-  return t;
-}
-
 }  // namespace
 
-/// Transforms FuncOp arguments and results from tensors to buffers. Tensor
-/// results are converted to memrefs and appended to the argument list.
-class HloToLhloFuncOpConverter : public OpConversionPattern<FuncOp> {
- public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      FuncOp funcOp, ArrayRef<Value> operands,
-      ConversionPatternRewriter& rewriter) const final {
-    if (funcOp.getBody().getBlocks().size() > 1) {
-      funcOp.emitOpError() << "tensor to buffer conversion expects a single "
-                              "block in the region containing the operation";
-      return failure();
-    }
-
-    auto funcType = funcOp.getType();
-
-    TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
-    for (auto argType : llvm::enumerate(funcType.getInputs())) {
-      conversion.addInputs(argType.index(), ConvertType(argType.value()));
-    }
-    for (auto resType : funcType.getResults()) {
-      conversion.addInputs(ConvertType(resType));
-    }
-    rewriter.updateRootInPlace(funcOp, [&] {
-      funcOp.setType(
-          rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
-      rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
-    });
-    return success();
-  }
-};
-
-/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each
-/// result to the corresponding buffer argument.
-class StdToLhloReturnOpConverter : public OpConversionPattern<mlir::ReturnOp> {
- public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      mlir::ReturnOp returnOp, ArrayRef<Value> operands,
-      ConversionPatternRewriter& rewriter) const final {
-    auto numReturnValues = returnOp.getNumOperands();
-    auto funcOp = returnOp.getParentOfType<FuncOp>();
-    auto numFuncArgs = funcOp.getNumArguments();
-    auto loc = returnOp.getLoc();
-
-    for (auto operand : llvm::enumerate(operands)) {
-      auto returnArgNumber = numFuncArgs - numReturnValues + operand.index();
-      auto dstBuffer = funcOp.getArgument(returnArgNumber);
-      if (dstBuffer == operand.value()) {
-        continue;
-      }
-
-      auto dealloc = FindInsertionPointForCopy(operand.value());
-
-      if (dealloc == nullptr) {
-        returnOp.emitOpError()
-            << "Missing dealloc for operand " << operand.index();
-        return failure();
-      }
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPoint(dealloc);
-      rewriter.create<xla_lhlo::CopyOp>(loc, llvm::None, operand.value(),
-                                        funcOp.getArgument(returnArgNumber));
-    }
-    rewriter.replaceOpWithNewOp<xla_lhlo::TerminatorOp>(returnOp);
-    return success();
-  }
-};
-
-void populateHLOToLHLOConversionPattern(MLIRContext* context,
-                                        OwningRewritePatternList* patterns) {
+void populateHLOToLHLOConversionPattern(
+    MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
+    TypeConverter* converter, OwningRewritePatternList* patterns) {
   // clang-format off
   patterns->insert<
       HloToLhloDynamicBroadcastInDimOpConverter,
-      HloToLhloFuncOpConverter,
       HloToLhloOpConverter<xla_hlo::AbsOp>,
       HloToLhloOpConverter<xla_hlo::AddOp>,
       HloToLhloOpConverter<xla_hlo::AndOp>,
@@ -472,8 +381,9 @@
       HloToLhloReduceOpConverter,
       HloToLhloTensorLoadOpConverter,
       HloToLhloTensorStoreOpConverter,
-      StdToLhloReturnOpConverter
-  >(context);
+      FunctionAndBlockSignatureConverter,
+      StdReturnOpConverter
+  >(context, bufferAssignment, converter);
   // clang-format on
 }
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index de808bc..a0a5e47 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -2590,6 +2590,21 @@
   }
 };
 
+ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
+  auto int_attr = attr.cast<DenseIntElementsAttr>();
+  auto type = val.getType().cast<ShapedType>();
+
+  SmallVector<int64_t, 6> axis;
+  axis.reserve(int_attr.getNumElements());
+
+  int64_t rank = type.getRank();
+  for (auto val : int_attr.getValues<APInt>()) {
+    axis.push_back((val.getSExtValue() + rank) % rank);
+  }
+
+  return builder->getI64TensorAttr(axis);
+}
+
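
The (val + rank) % rank step above folds negative axis indices, which
TensorFlow permits as "count from the end", into the canonical range
[0, rank). A minimal standalone sketch of the same arithmetic, with rank and
axis values chosen purely for illustration:

    #include <cstdint>
    #include <vector>

    // Normalizes possibly-negative axes for a tensor of the given rank,
    // mirroring the (val + rank) % rank step in ConvertAxisAttr above.
    std::vector<int64_t> NormalizeAxes(const std::vector<int64_t>& axes,
                                       int64_t rank) {
      std::vector<int64_t> result;
      result.reserve(axes.size());
      for (int64_t axis : axes) {
        result.push_back((axis + rank) % rank);  // rank 3: -1 -> 2, 1 -> 1
      }
      return result;
    }
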
 /// Converts the LinSpace tensorflow op to a xla_hlo.iota op with a scaling
 /// and offset applied to generate the linspace values. The output tensor needs
 /// to have a static shape.  The implementation is defined in C++ because there
@@ -4182,6 +4197,68 @@
   }
 };
 
+// Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO.
+class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TF::InplaceUpdateOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input = op.x();
+    auto indices = op.i();
+    auto updates = op.v();
+
+    // Slice each row of `i` and `v` to perform a separate dynamic-update-slice
+    // on the contents of `x`.
+    auto input_type = input.getType().cast<ShapedType>();
+    auto updates_type = updates.getType().cast<ShapedType>();
+    auto indices_type = indices.getType().cast<ShapedType>();
+    if (!indices_type.hasStaticShape()) return failure();
+
+    if (indices_type.getRank() != 1) return failure();
+
+    SmallVector<Type, 4> unpacked_indices_type(
+        indices_type.getDimSize(0),
+        RankedTensorType::get({}, indices_type.getElementType()));
+    auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(64), 0);
+    auto unpacked_indices = rewriter.create<TF::UnpackOp>(
+        op.getLoc(), unpacked_indices_type, indices, zero_attr);
+
+    SmallVector<int64_t, 4> split_updates_shape;
+    split_updates_shape.append(updates_type.getShape().begin(),
+                               updates_type.getShape().end());
+    split_updates_shape.front() = 1;
+    SmallVector<Type, 4> split_updates_type;
+    split_updates_type.resize(
+        updates_type.getShape().front(),
+        RankedTensorType::get(split_updates_shape,
+                              updates_type.getElementType()));
+
+    auto cst =
+        rewriter.create<xla_hlo::ConstOp>(op.getLoc(), zero_attr).getResult();
+    auto split_updates = rewriter.create<TF::SplitOp>(
+        op.getLoc(), split_updates_type, cst, updates);
+
+    SmallVector<Value, 6> input_indices;
+    input_indices.resize(input_type.getRank(), cst);
+
+    SmallVector<int64_t, 6> starts(updates_type.getRank(), 0);
+    SmallVector<int64_t, 6> strides(updates_type.getRank(), 1);
+    SmallVector<int64_t, 6> limits(updates_type.getShape().begin(),
+                                   updates_type.getShape().end());
+
+    for (auto pair :
+         llvm::zip(unpacked_indices.output(), split_updates.output())) {
+      input_indices.front() = std::get<0>(pair);
+      input = rewriter.create<xla_hlo::DynamicUpdateSliceOp>(
+          op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
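
For reference, TF::InplaceUpdateOp semantically performs
x[i[k], ...] = v[k, ...] for each k; the pattern above realizes this by
unpacking `i`, splitting `v` row-wise, and chaining one dynamic-update-slice
per row. A rough C++ model of that semantics on a row-major 2-D buffer (the
names and flat layout are illustrative assumptions, not the op's
implementation):

    #include <cstdint>
    #include <vector>

    // Illustrative-only model of InplaceUpdate on a row-major [rows, cols]
    // buffer: for each k, overwrite row indices[k] of x with row k of
    // updates. Assumes indices are in range and updates has
    // indices.size() rows.
    void InplaceUpdateRows(std::vector<float>& x, int64_t cols,
                           const std::vector<int64_t>& indices,
                           const std::vector<float>& updates) {
      for (size_t k = 0; k < indices.size(); ++k) {
        for (int64_t c = 0; c < cols; ++c) {
          x[indices[k] * cols + c] = updates[k * cols + c];
        }
      }
    }
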
 // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO.
 class ConvertXlaDynamicUpdateSliceOp
     : public OpRewritePattern<TF::XlaDynamicUpdateSliceOp> {
@@ -4863,12 +4940,13 @@
       ConvertConv3DBackpropInputOp, ConvertCumsumOp, ConvertDiagPartOp,
       ConvertEinsumOp, ConvertFusedBatchNormGradOp,
       ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
-      ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
-      ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
-      ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
-      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
-      ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op,
-      ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
+      ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
+      ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
+      ConvertAvgPoolOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
+      ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
+      ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
+      ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
+      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
       ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index d53dbdc..2a27c1f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -274,6 +274,13 @@
             (CastElementsToI64Elements $group_assignment))>;
 
 //===----------------------------------------------------------------------===//
+// All2All op patterns.
+//===----------------------------------------------------------------------===//
+
+def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count),
+          (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>;
+
+//===----------------------------------------------------------------------===//
 // FFT op patterns.
 //===----------------------------------------------------------------------===//
 
@@ -514,6 +521,16 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Reverse op patterns.
+//===----------------------------------------------------------------------===//
+
+// Handles axis conversion for TF reverse.
+def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">;
+
+def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)),
+    (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>;
+
+//===----------------------------------------------------------------------===//
 // Ternary op patterns.
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index be6ba16..86a2def 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -102,6 +102,7 @@
     TypeID::get<TF::CastOp>(),
     TypeID::get<TF::ClipByValueOp>(),
     TypeID::get<TF::ComplexAbsOp>(),
+    TypeID::get<TF::ConjugateTransposeOp>(),
     TypeID::get<TF::CoshOp>(),
     TypeID::get<TF::CrossOp>(),
     TypeID::get<TF::DataFormatDimMapOp>(),
@@ -135,9 +136,11 @@
     TypeID::get<TF::MulOp>(),
     TypeID::get<TF::NegOp>(),
     TypeID::get<TF::NotEqualOp>(),
+    TypeID::get<TF::PadOp>(),
     TypeID::get<TF::PlaceholderWithDefaultOp>(),
     TypeID::get<TF::PowOp>(),
     TypeID::get<TF::RealDivOp>(),
+    TypeID::get<TF::ReciprocalOp>(),
     TypeID::get<TF::ReciprocalGradOp>(),
     TypeID::get<TF::Relu6GradOp>(),
     TypeID::get<TF::RightShiftOp>(),
@@ -162,6 +165,8 @@
     TypeID::get<TF::TruncateModOp>(),
     TypeID::get<TF::UnpackOp>(),
     TypeID::get<TF::XdivyOp>(),
+    TypeID::get<TF::XlaBroadcastHelperOp>(),
+    TypeID::get<TF::XlaConvOp>(),
     TypeID::get<TF::XlaDotOp>(),
     TypeID::get<TF::XlaPadOp>(),
     TypeID::get<TF::Xlog1pyOp>(),
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
index e6f3ac0..f0eb3cc 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
@@ -21,7 +21,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
-#include "mlir/Dialect/LoopOps/LoopOps.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
@@ -112,7 +112,7 @@
       auto step = rewriter.create<mlir::ConstantOp>(
           loc, rewriter.getIndexType(),
           rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
-      auto loop = rewriter.create<mlir::loop::ForOp>(loc, zero, upper, step);
+      auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);
 
       rewriter.setInsertionPointToStart(loop.getBody());
       // Compute memrefs for the value to reduce. This makes it easier to just
@@ -173,8 +173,7 @@
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
-                           gpu::GPUDialect, loop::LoopOpsDialect,
-                           XlaLhloDialect>();
+                           gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
     target.addIllegalOp<ReduceOp>();
     auto func = getFunction();
     patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
index 54b3acd..c5f5b39 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
@@ -18,7 +18,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
-#include "mlir/Dialect/LoopOps/LoopOps.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -64,12 +64,12 @@
 // into a reduction operator of loop.reduce by doing buffer allocation for
 // scalar arguments and the result of `loop.reduce` to make it compatible with
 // LHLO ops.
-void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
+void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
                                 Block* lhlo_block, OpBuilder* b) {
   Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
   OpBuilder::InsertionGuard guard(*b);
   b->setInsertionPointToStart(&loop_reduce_op_body);
-  b->create<loop::ReduceReturnOp>(
+  b->create<scf::ReduceReturnOp>(
       loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
                                      lhlo_block, b));
 }
@@ -136,9 +136,9 @@
   return mapped_ivs;
 }
 
-// Returns loop::Parallel over a shaped value with static or dynamic shape.
-loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
-                                   OpBuilder* b) {
+// Returns scf::Parallel over a shaped value with static or dynamic shape.
+scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
+                                  OpBuilder* b) {
   Value zero = b->create<ConstantIndexOp>(loc, 0);
   Value one = b->create<ConstantIndexOp>(loc, 1);
 
@@ -151,10 +151,10 @@
     lower.push_back(zero);
     step.push_back(one);
   }
-  return b->create<loop::ParallelOp>(loc, lower, upper, step);
+  return b->create<scf::ParallelOp>(loc, lower, upper, step);
 }
 
-// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp.
+// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
 // The outer `ParallelOp` refers to the parallel loops if there are
 // any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
 // contains the reduction operator.
@@ -197,7 +197,7 @@
     // TODO(b/137624192) Implement variadic reduce.
     if (xla_reduce_op.out().size() != 1) return failure();
 
-    loop::ReduceOp reduce_op =
+    scf::ReduceOp reduce_op =
         CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
     ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
                                &xla_reduce_op.body().front(), &rewriter);
@@ -225,7 +225,7 @@
   //    } : f32
   //    loop.yield
   //  }
-  loop::ReduceOp CreateReduceOpInNestedParallelLoops(
+  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
       xla_lhlo::ReduceOp xla_reduce_op,
       ConversionPatternRewriter* rewriter) const {
     auto loc = xla_reduce_op.getLoc();
@@ -254,13 +254,13 @@
     SmallVector<Value, 1> init_value = {
         rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())};
     // Outer ParallelOp is not needed if it is a reduction across all dims.
-    loop::ParallelOp outer;
+    scf::ParallelOp outer;
     if (!parallel_lower.empty()) {
-      outer = rewriter->create<loop::ParallelOp>(loc, parallel_lower,
-                                                 parallel_upper, parallel_step);
+      outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
+                                                parallel_upper, parallel_step);
       rewriter->setInsertionPointToStart(outer.getBody());
     }
-    loop::ParallelOp inner = rewriter->create<loop::ParallelOp>(
+    scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
         loc, reduce_lower, reduce_upper, reduce_step, init_value);
     Value reduction_result = *inner.getResults().begin();
 
@@ -294,7 +294,7 @@
     rewriter->setInsertionPointToStart(inner.getBody());
     Value elem = rewriter->create<mlir::LoadOp>(
         loc, *xla_reduce_op.operands().begin(), indices);
-    return rewriter->create<loop::ReduceOp>(loc, elem);
+    return rewriter->create<scf::ReduceOp>(loc, elem);
   }
 };
 
@@ -314,8 +314,8 @@
 //     accumulator = reduction_operator(output[O], value)
 //   output[O] = accumulator
 //
-// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a
-// loop::ReduceOp.
+// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
+// scf::ReduceOp.
 // The outer `ParallelOp` refers to the parallel loops that traverse the
 // output buffer. The inner `ParallelOp` refers to the reduction loops that traverse
 // reduction windows and `ReduceOp` contains the reduction operator.
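
A scalar model of the loop nest this comment describes, assuming a 1-D input,
a max reduction, unit stride, and no padding (all simplifying assumptions, not
the pass's actual parameters):

    #include <algorithm>
    #include <vector>

    // Illustrative reduce-window: the outer loop plays the role of the
    // parallel loop over the output, the inner loop the reduction loop over
    // the window. Assumes window <= in.size().
    std::vector<float> ReduceWindowMax1D(const std::vector<float>& in,
                                         int window, float init) {
      std::vector<float> out(in.size() - window + 1, init);
      for (size_t o = 0; o < out.size(); ++o) {   // loop over output
        for (int w = 0; w < window; ++w) {        // loop over window
          out[o] = std::max(out[o], in[o + w]);   // reduction operator
        }
      }
      return out;
    }
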
@@ -366,12 +366,12 @@
   LogicalResult matchAndRewrite(
       xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
       ConversionPatternRewriter& rewriter) const final {
-    loop::ParallelOp output_loop, window_loop;
+    scf::ParallelOp output_loop, window_loop;
     std::tie(output_loop, window_loop) =
         CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
                                                      &rewriter);
 
-    loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
+    scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
         xla_reduce_window_op, output_loop, window_loop, &rewriter);
 
     ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
@@ -381,7 +381,7 @@
   }
 
  private:
-  std::pair<loop::ParallelOp, loop::ParallelOp>
+  std::pair<scf::ParallelOp, scf::ParallelOp>
   CreateParallelLoopsToTraverseOutputAndWindow(
       xla_lhlo::ReduceWindowOp xla_reduce_window_op,
       ConversionPatternRewriter* rewriter) const {
@@ -405,7 +405,7 @@
       window_upper.push_back(
           rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
     }
-    auto window_loop = rewriter->create<loop::ParallelOp>(
+    auto window_loop = rewriter->create<scf::ParallelOp>(
         loc, window_lower, window_upper, window_step, init_value);
 
     Value reduction_result = *window_loop.getResults().begin();
@@ -414,9 +414,9 @@
     return std::make_pair(output_loop, window_loop);
   }
 
-  loop::ReduceOp CreateReduceOpInNestedParallelLoops(
+  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
       xla_lhlo::ReduceWindowOp xla_reduce_window_op,
-      loop::ParallelOp output_loop, loop::ParallelOp window_loop,
+      scf::ParallelOp output_loop, scf::ParallelOp window_loop,
       ConversionPatternRewriter* rewriter) const {
     rewriter->setInsertionPointToStart(window_loop.getBody());
     auto loc = xla_reduce_window_op.getLoc();
@@ -436,20 +436,20 @@
         xla_reduce_window_op, output_loop.getInductionVars(),
         window_loop.getInductionVars(), rewriter);
 
-    auto elem_or_init = rewriter->create<loop::IfOp>(
+    auto elem_or_init = rewriter->create<scf::IfOp>(
         loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
         /*withElseRegion=*/true);
 
     OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
     Value elem = then_builder.create<mlir::LoadOp>(
         loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
-    then_builder.create<loop::YieldOp>(loc, elem);
+    then_builder.create<scf::YieldOp>(loc, elem);
 
     OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
-    else_builder.create<loop::YieldOp>(loc, *window_loop.initVals().begin());
+    else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
 
-    return rewriter->create<loop::ReduceOp>(loc,
-                                            *elem_or_init.results().begin());
+    return rewriter->create<scf::ReduceOp>(loc,
+                                           *elem_or_init.results().begin());
   }
 };
 
@@ -490,7 +490,7 @@
       ConversionPatternRewriter& rewriter) const final {
     auto loc = s_and_s_op.getLoc();
     InitializeOutput(s_and_s_op, &rewriter);
-    loop::ParallelOp loop_over_src =
+    scf::ParallelOp loop_over_src =
         MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
     rewriter.setInsertionPointToStart(loop_over_src.getBody());
 
@@ -520,7 +520,7 @@
     auto loc = s_and_s_op.getLoc();
     Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
 
-    loop::ParallelOp loop_over_output =
+    scf::ParallelOp loop_over_output =
         MakeLoopOverShape(loc, s_and_s_op.out(), b);
     OpBuilder::InsertionGuard guard(*b);
     b->setInsertionPointToStart(loop_over_output.getBody());
@@ -531,10 +531,10 @@
   struct WindowLoops {
     SmallVector<Value, 2> selected_ivs;
     SmallVector<Value, 2> window_ivs;
-    loop::ForOp inner_loop;
+    scf::ForOp inner_loop;
   };
   WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
-                                loop::ParallelOp loop_over_src,
+                                scf::ParallelOp loop_over_src,
                                 OpBuilder* b) const {
     auto loc = s_and_s_op.getLoc();
     Value zero = b->create<ConstantIndexOp>(loc, 0);
@@ -558,12 +558,12 @@
          s_and_s_op.window_dimensions()->getIntValues()) {
       Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
       result.inner_loop =
-          b->create<loop::ForOp>(loc, zero, upper, one, iter_args);
+          b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
       if (b->getInsertionBlock() == loop_over_src.getBody()) {
         ip = b->saveInsertionPoint();
         result.selected_ivs = result.inner_loop.getResults().take_front(rank);
       } else {
-        b->create<loop::YieldOp>(loc, result.inner_loop.getResults());
+        b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
       }
       b->setInsertionPointToStart(result.inner_loop.getBody());
       iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
@@ -599,7 +599,7 @@
   };
 
   SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
-                                  loop::ParallelOp loop_over_src,
+                                  scf::ParallelOp loop_over_src,
                                   OpBuilder* b) const {
     auto loc = s_and_s_op.getLoc();
 
@@ -614,7 +614,7 @@
 
     IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());
 
-    auto if_in_bounds = inner_loop_b.create<loop::IfOp>(
+    auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
         loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
         /*withElseRegion=*/true);
 
@@ -623,16 +623,16 @@
       OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
       auto select_or_init_results = SelectOrInitialize(
           s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
-      in_bounds_then_b.create<loop::YieldOp>(loc, select_or_init_results);
+      in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
     }
 
     // Case when we are in the pad.
     {
       OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
-      in_bounds_else_b.create<loop::YieldOp>(loc, ivs_val_flag.to_vector());
+      in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
     }
 
-    inner_loop_b.create<loop::YieldOp>(loc, if_in_bounds.getResults());
+    inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
     return window_loops.selected_ivs;
   }
 
@@ -647,8 +647,8 @@
     Value operand_elem =
         b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
     auto if_init =
-        b->create<loop::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
-                              /*withElseRegion=*/true);
+        b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
+                             /*withElseRegion=*/true);
     // Init == true, i.e. iter args are already initialized with a selected
     // element in boundaries of the operand. Select function has to be computed
     // here.
@@ -660,32 +660,31 @@
           ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
                                     &lhlo_select, &if_init_then_b);
 
-      auto if_pred =
-          if_init_then_b.create<loop::IfOp>(loc, iter_arg_types, pred,
-                                            /*withElseRegion=*/true);
+      auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
+                                                      /*withElseRegion=*/true);
 
       // Pred == true, therefore pack newly selected ivs, val and init flag back
       // to iter_args and return.
       {
         OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
-        if_pred_then_b.create<loop::YieldOp>(
+        if_pred_then_b.create<scf::YieldOp>(
             loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
       }
 
       // Pred == false, therefore return old iter_args.
       {
         OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
-        if_pred_else_b.create<loop::YieldOp>(loc, ivs_val_flag->to_vector());
+        if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
       }
 
-      if_init_then_b.create<loop::YieldOp>(loc, if_pred.getResults());
+      if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
     }
     // Init == false, i.e. only pad was visited before and this is the first
     // element in the boundaries of the operand.
     {
       OpBuilder if_init_else_b = if_init.getElseBodyBuilder();
 
-      if_init_else_b.create<loop::YieldOp>(
+      if_init_else_b.create<scf::YieldOp>(
           loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
     }
     return if_init.getResults();
@@ -708,7 +707,7 @@
 
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
-                           loop::LoopOpsDialect, XlaLhloDialect>();
+                           scf::SCFDialect, XlaLhloDialect>();
     target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
                         xla_lhlo::SelectAndScatterOp>();
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
index 982ec4f..c317dc3 100644
--- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
@@ -281,11 +281,9 @@
     // No conversion is needed for the same width integers
     return args.front();
   }
-  // TODO(dfki-ehna): Add other primitive type conversions
-  // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
-  //   return b.create<mlir::FpToSiOp>(loc, result_types,
-  //   args,mlir::None);
-  // }
+  if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
+    return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
+  }
   return nullptr;
 }
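
The FPToSIOp path enabled above corresponds to LLVM's fptosi, which truncates
toward zero for in-range values; as a rough C++ analogy (a model only, since
the MLIR op additionally goes through the areCastCompatible check above):

    #include <cassert>
    #include <cstdint>

    int main() {
      // fptosi-style conversion drops the fractional part toward zero.
      assert(static_cast<int32_t>(2.7f) == 2);
      assert(static_cast<int32_t>(-2.7f) == -2);
      return 0;
    }
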
 
diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
index a4ffa57..bf66640 100644
--- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
@@ -50,12 +50,6 @@
 template <typename SrcOp>
 bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter,
                                        Value *out_lhs, Value *out_rhs) {
-  if (!op.broadcast_dimensions().hasValue()) {
-    // Note: the op may still have an implicit broadcast on it, such as
-    // for (tensor<1xf32>, tensor<4xf32>).
-    return false;
-  }
-
   // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args,
   // replacing the original LHS and RHS args in the source op with the results
   // of the broadcasts.
@@ -79,25 +73,7 @@
 
   auto lhs_rank = lhs_ranked_type.getRank();
   auto rhs_rank = rhs_ranked_type.getRank();
-
-  // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg.
-  // Use the original op.broadcast_dimensions for the lower rank arg.
-  auto higher_rank_broadcast_dims =
-      GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter);
-  DenseIntElementsAttr lhs_broadcast_dims;
-  DenseIntElementsAttr rhs_broadcast_dims;
-  if (lhs_rank > rhs_rank) {
-    lhs_broadcast_dims = higher_rank_broadcast_dims;
-    rhs_broadcast_dims = op.broadcast_dimensions().getValue();
-  } else if (lhs_rank < rhs_rank) {
-    lhs_broadcast_dims = op.broadcast_dimensions().getValue();
-    rhs_broadcast_dims = higher_rank_broadcast_dims;
-  } else {
-    // This shouldn't happen for legal ops. If the broadcast_dimensions
-    // attribute is set, the ranks should be different.
-    // TODO(scotttodd): Add a custom verification for ops and assert here.
-    return false;
-  }
+  ArrayRef<int64_t> op_shape = op_ranked_type.getShape();
 
   // BroadcastInDimOp must have the same element type for operands and results,
   // so preserve the original output shape and the original input element type.
@@ -105,16 +81,32 @@
   //   broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32>
   //   broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32>
   //   SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
-  ArrayRef<int64_t> op_shape = op_ranked_type.getShape();
-  auto lhs_type =
-      RankedTensorType::get(op_shape, lhs_ranked_type.getElementType());
-  auto rhs_type =
-      RankedTensorType::get(op_shape, rhs_ranked_type.getElementType());
+  if (lhs_ranked_type.getShape() != op_ranked_type.getShape()) {
+    auto type =
+        RankedTensorType::get(op_shape, lhs_ranked_type.getElementType());
+    DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, lhs_rank, rewriter);
+    if (lhs_rank < rhs_rank) {
+      attr = op.broadcast_dimensions().getValue();
+    }
 
-  *out_lhs = rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), lhs_type,
-                                                      lhs, lhs_broadcast_dims);
-  *out_rhs = rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), rhs_type,
-                                                      rhs, rhs_broadcast_dims);
+    lhs =
+        rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, lhs, attr);
+  }
+
+  if (rhs_ranked_type.getShape() != op_ranked_type.getShape()) {
+    auto type =
+        RankedTensorType::get(op_shape, rhs_ranked_type.getElementType());
+    DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, rhs_rank, rewriter);
+    if (rhs_rank < lhs_rank) {
+      attr = op.broadcast_dimensions().getValue();
+    }
+
+    rhs =
+        rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, rhs, attr);
+  }
+
+  *out_lhs = lhs;
+  *out_rhs = rhs;
   return true;
 }
 
@@ -359,9 +351,15 @@
 
 void SetupMaterializeBroadcastsLegality(MLIRContext *context,
                                         ConversionTarget *conversionTarget) {
-#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \
-  conversionTarget->addDynamicallyLegalOp<OpType>(      \
-      [](OpType op) { return !op.broadcast_dimensions().hasValue(); });
+#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType)           \
+  conversionTarget->addDynamicallyLegalOp<OpType>([](OpType op) { \
+    if (op.broadcast_dimensions().hasValue()) return false;       \
+    auto l = op.lhs().getType().cast<ShapedType>();               \
+    auto r = op.rhs().getType().cast<ShapedType>();               \
+    if (!l.hasRank() || !r.hasRank()) return false;               \
+    return l.getShape() == r.getShape();                          \
+  });
+
   // Binary elementwise ops.
   ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp);
   ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op);
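
The revised predicate keeps an op legal only when it has no explicit
broadcast_dimensions and both operands are ranked with identical shapes, so a
rank-equal implicit broadcast such as tensor<4x1xf32> vs tensor<1x4xf32> (the
new @addBroadcastEqual test above) is now rewritten as well. A standalone
restatement of that logic, with illustrative shape types:

    #include <cstdint>
    #include <optional>
    #include <vector>

    // Mirrors the macro: legal only if no broadcast_dimensions attribute is
    // present and both operand shapes are known (ranked) and exactly equal.
    bool IsLegalWithoutBroadcast(
        const std::optional<std::vector<int64_t>>& broadcast_dimensions,
        const std::optional<std::vector<int64_t>>& lhs_shape,
        const std::optional<std::vector<int64_t>>& rhs_shape) {
      if (broadcast_dimensions.has_value()) return false;
      if (!lhs_shape || !rhs_shape) return false;  // unranked: must rewrite
      return *lhs_shape == *rhs_shape;
    }
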
diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
index ad81cda..9cde6f8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 
 namespace mlir {
+class BufferAssignmentPlacer;
 namespace xla_hlo {
 
 // Collection of rewrite patterns for lowering a general dot product.
@@ -38,9 +39,9 @@
                               MLIRContext *ctx);
 
 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
-void populateHLOToLHLOConversionPattern(MLIRContext *context,
-                                        OwningRewritePatternList *patterns);
-
+void populateHLOToLHLOConversionPattern(
+    MLIRContext *context, BufferAssignmentPlacer *bufferAssignment,
+    TypeConverter *converter, OwningRewritePatternList *patterns);
 // Collection of rewrite patterns for lowering of HLO to Linalg dialect.
 void populateHLOToLinalgConversionPattern(MLIRContext *context,
                                           OwningRewritePatternList *patterns);
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc
index 436a3e7..a12bd9e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc
@@ -251,17 +251,15 @@
 
   // Create the view for this slice size, possibly with an affine map to model
   // the offset. The result is cached in the slices_ map.
-  SmallVector<AffineMap, 1> offset_map;
-  if (slice.offset()) {
-    offset_map.push_back(AffineMap::get(
-        /*dimCount=*/1, /*symbolCount=*/0,
-        {getAffineDimExpr(0, builder_.getContext()) + slice.offset()},
-        builder_.getContext()));
-  }
-  auto slice_type = MemRefType::get({slice.size()}, i8_type_, offset_map);
+  // The std.view result type does not carry the static offset: this is not
+  // useful information. Rather, the offset is passed to the view op as an
+  // explicit byte-shift operand.
+  auto slice_type = MemRefType::get({slice.size()}, i8_type_, {});
 
-  auto slice_view = builder_.create<ViewOp>(
-      alloc_buffer.getLoc(), slice_type, alloc_buffer, /*operands=*/llvm::None);
+  Value byte_shift =
+      builder_.create<ConstantIndexOp>(alloc_buffer.getLoc(), slice.offset());
+  auto slice_view =
+      builder_.create<ViewOp>(alloc_buffer.getLoc(), slice_type, alloc_buffer,
+                              byte_shift, /*sizes=*/ArrayRef<Value>{});
   slices_.insert({slice_key, slice_view});
   return slice_view;
 }
@@ -277,9 +275,12 @@
   Value slice_view = GetOrCreateView(out_slice);
   TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType<MemRefType>(
                                          target_shape, builder_));
+  Value byte_shift =
+      builder_.create<ConstantIndexOp>(builder_.getUnknownLoc(), 0);
   if (slice_view.getType() != out_type)
-    slice_view = builder_.create<ViewOp>(builder_.getUnknownLoc(), out_type,
-                                         slice_view, llvm::None);
+    slice_view =
+        builder_.create<ViewOp>(builder_.getUnknownLoc(), out_type, slice_view,
+                                byte_shift, /*sizes=*/ArrayRef<Value>{});
   return slice_view;
 }
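
The two hunks above replace the affine-map encoding of a slice's offset with
an explicit byte_shift operand on std.view. As a plain pointer-arithmetic
analogy (illustrative only, not how ViewOp is implemented), the view is the
base i8 buffer advanced by the byte offset and reinterpreted at the target
element type:

    #include <cstdint>

    // Analogy for a byte-shifted view: base is the i8 allocation, byte_shift
    // is the slice offset, and the result is reinterpreted as f32 storage.
    float* ViewAsF32(std::uint8_t* base, std::int64_t byte_shift) {
      return reinterpret_cast<float*>(base + byte_shift);
    }
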
 
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 5cca1e6..ea4ba8d 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -562,6 +562,7 @@
     name = "dynamic_slice_ops_test",
     size = "small",
     srcs = ["dynamic_slice_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -1386,6 +1387,7 @@
     size = "medium",
     srcs = ["fused_batchnorm_test.py"],
     python_version = "PY3",
+    shard_count = 5,
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
     ],
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index bd01319..00ed6d8 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1096,8 +1096,6 @@
             x,
             expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/155097273): Handle complex dtype constants")
   def testExpandDims(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1195,8 +1193,6 @@
         np.full([1, 1, 3, 5], 3., dtype=np.float32),
         expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/155097273): Handle complex dtype constants")
   def testPad(self):
     for dtype, pad_type in itertools.product(
         self.numeric_types, [np.int32, np.int64]):
@@ -1337,8 +1333,6 @@
               ],
               dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/155097273): Handle complex dtype constants")
   def testReshape(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1471,8 +1465,6 @@
                [1, 2]],
               dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/155097273): Handle complex dtype constants")
   def testTranspose(self):
     for dtype in self.numeric_types:
       self._testBinary(
@@ -1491,8 +1483,6 @@
           np.array([1, 0], dtype=np.int32),
           expected=np.array([[1, 3], [2, 4]], dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      "TODO(b/155097273): Handle complex dtype constants")
   def testConjugateTranspose(self):
     for dtype in self.complex_types:
       self._testBinary(
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index ff35587..a1bb64e 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -48,7 +48,8 @@
       {'start': 1, 'end': 2, 'num': 1},
       {'start': 1, 'end': 4, 'num': 3},
       {'start': 0, 'end': 41, 'num': 42})
-  @test_util.disable_mlir_bridge('Requires dynamic shape handling')
+  @test_util.disable_mlir_bridge(
+      'TODO(b/156174708): Dynamic result types not supported')
   def testLinspace(self, start, end, num):
     expected = np.linspace(start, end, num, dtype=np.float32)
     result = self._testTernary(
@@ -182,7 +183,6 @@
           np.array([8, 9], dtype=dtype),
           expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
 
-  @test_util.disable_mlir_bridge('TODO(b/155097273)')
   def testSlice(self):
     for dtype in self.numeric_types:
       self._testTernary(
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index dc11c24..3e36f67 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -601,8 +601,7 @@
           np.array([-1, -0.5, 0, 0.3], dtype=dtype),
           expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      "Complex types not supported in CreateDenseElementsAttrFromLiteral")
+  @test_util.disable_mlir_bridge("TODO(b/156135423): Fix ConvertSigmoidOp")
   def testComplexOps(self):
     for dtype in self.complex_types:
 
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index b01c5ae..1f83701 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -51,7 +51,6 @@
         equality_fn = self.assertAllClose
       equality_fn(result, expected, rtol=1e-3)
 
-  @test_util.disable_mlir_bridge('Not supported yet')
   def testAdd(self):
     for dtype in self.numeric_types:
       self._assertOpOutputMatchesExpected(
@@ -109,7 +108,6 @@
                       xla_data_pb2.PrecisionConfig.HIGHEST)
 
   @parameterized.parameters(*PRECISION_VALUES)
-  @test_util.disable_mlir_bridge('Not supported yet')
   def testConv(self, precision):
     for dtype in set(self.float_types).intersection(
         set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
@@ -194,8 +192,6 @@
           args=(np.array([1, 2, 3], dtype=dtype),),
           expected=np.array([-1, -2, -3], dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      'Requires XlaPad op shape inference to have static result types')
   def testPad(self):
     for dtype in self.numeric_types:
 
diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index cd52e2f..404f9eb 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -70,6 +70,12 @@
   return *this;
 }
 
+ExecutableBuildOptions& ExecutableBuildOptions::set_use_spmd_partitioning(
+    bool use_spmd_partitioning) {
+  use_spmd_partitioning_ = use_spmd_partitioning;
+  return *this;
+}
+
 ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment(
     const DeviceAssignment& device_assignment) {
   device_assignment_ = device_assignment;
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 360ad02..9a7fdd9 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -77,6 +77,11 @@
   int num_partitions() const { return num_partitions_; }
   ExecutableBuildOptions& set_num_partitions(int num_partitions);
 
+  // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
+  // num_partitions > 1 and XLA is requested to partition the input program.
+  bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
+  ExecutableBuildOptions& set_use_spmd_partitioning(bool use_spmd_partitioning);
+
   // If set, this specifies a static device assignment for the computation.
   // Otherwise, the computation will be compiled generically and can be run with
   // any device assignment compatible with the computation's replica and
@@ -104,6 +109,7 @@
   se::DeviceMemoryAllocator* device_allocator_ = nullptr;
   int num_replicas_ = 1;
   int num_partitions_ = 1;
+  bool use_spmd_partitioning_ = false;
   absl::optional<DeviceAssignment> device_assignment_;
   bool alias_passthrough_params_ = false;
 };
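
A client that asks XLA to partition a program can now opt into SPMD along
these lines (a minimal sketch; the surrounding compile call and the partition
count of 4 are illustrative):

    #include "tensorflow/compiler/xla/client/executable_build_options.h"

    void ConfigureSpmd(xla::ExecutableBuildOptions& options) {
      options.set_num_partitions(4);            // request partitioning
      options.set_use_spmd_partitioning(true);  // SPMD rather than MPMD
    }
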
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 0b146a4..bd70ce8 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1301,7 +1301,6 @@
     int64 feature_group_count, int64 batch_group_count,
     const PrecisionConfig* precision_config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
     TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
     TF_RETURN_IF_ERROR(
@@ -1314,30 +1313,45 @@
       window_dimensions[i] =
           rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
     }
-    TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
+
+    TF_ASSIGN_OR_RETURN(Window window,
                         ShapeInference::InferWindowFromDimensions(
                             window_dimensions, window_strides, padding,
                             lhs_dilation, rhs_dilation));
-
-    TF_ASSIGN_OR_RETURN(
-        Shape shape, ShapeInference::InferConvolveShape(
-                         *lhs_shape, *rhs_shape, feature_group_count,
-                         batch_group_count, instr.window(), dimension_numbers));
-    *instr.mutable_shape() = shape.ToProto();
-
-    *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
-    instr.set_feature_group_count(feature_group_count);
-    instr.set_batch_group_count(batch_group_count);
-
-    if (precision_config != nullptr) {
-      *instr.mutable_precision_config() = *precision_config;
-    }
-
-    return AddInstruction(std::move(instr), HloOpcode::kConvolution,
-                          {lhs, rhs});
+    TF_ASSIGN_OR_RETURN(Shape shape,
+                        ShapeInference::InferConvolveShape(
+                            *lhs_shape, *rhs_shape, feature_group_count,
+                            batch_group_count, window, dimension_numbers));
+    return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
+                                      padding, lhs_dilation, rhs_dilation,
+                                      dimension_numbers, feature_group_count,
+                                      batch_group_count, precision_config);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
+    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
+    absl::Span<const int64> window_strides,
+    absl::Span<const std::pair<int64, int64>> padding,
+    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count, int64 batch_group_count,
+    const PrecisionConfig* precision_config) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+
+  *instr.mutable_window() = window;
+  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
+  instr.set_feature_group_count(feature_group_count);
+  instr.set_batch_group_count(batch_group_count);
+
+  if (precision_config != nullptr) {
+    *instr.mutable_precision_config() = *precision_config;
+  }
+
+  return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs});
+}
+
 XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
                       const absl::Span<const int64> fft_length) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index bfb97d7..33fe62e 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -491,6 +491,16 @@
                            int64 batch_group_count = 1,
                            const PrecisionConfig* precision_config = nullptr);
 
+  virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
+      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
+      absl::Span<const int64> window_strides,
+      absl::Span<const std::pair<int64, int64>> padding,
+      absl::Span<const int64> lhs_dilation,
+      absl::Span<const int64> rhs_dilation,
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      int64 feature_group_count, int64 batch_group_count,
+      const PrecisionConfig* precision_config);
+
   XlaOp Fft(XlaOp operand, FftType fft_type,
             absl::Span<const int64> fft_length);
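The new `ConvGeneralDilatedInternal` hook separates client-side shape inference from HLO construction, so an `XlaBuilder` subclass can intercept every convolution after its output shape is known. A minimal sketch, assuming the hook is visible to subclasses; `LoggingBuilder` and its log line are illustrative and not part of this change:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

class LoggingBuilder : public XlaBuilder {
 public:
  using XlaBuilder::XlaBuilder;

 protected:
  StatusOr<XlaOp> ConvGeneralDilatedInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config) override {
    // Observe the inferred output shape, then defer to the base class,
    // which builds the kConvolution HloInstructionProto.
    LOG(INFO) << "convolution -> " << ShapeUtil::HumanString(shape);
    return XlaBuilder::ConvGeneralDilatedInternal(
        shape, lhs, rhs, window, window_strides, padding, lhs_dilation,
        rhs_dilation, dimension_numbers, feature_group_count,
        batch_group_count, precision_config);
  }
};

}  // namespace xla
```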
 
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
index ef0caff..6d4482a 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
@@ -20,6 +20,9 @@
 
 from absl import logging
 
+# Import xla_client to load shared C++ extensions (just CompileOptions at the
+# time of writing).
+from tensorflow.compiler.xla.python import xla_client  # pylint: disable=unused-import
 from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client
 
 
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 499c4e2..3349528 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3846,6 +3846,7 @@
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:core",
         "@llvm-project//llvm:transform_utils",
     ],
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 8c76e91..ce9c8a4 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -91,6 +91,7 @@
     TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
         execution_options.mutable_device_assignment()));
   }
+  execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning());
   for (const AotXlaComputationInstance& instance : computations) {
     TF_RET_CHECK(instance.computation.has_host_program_shape());
     *execution_options.mutable_shape_with_output_layout() =
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index cf64615..57b24e3 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -76,6 +76,7 @@
 
   virtual int64 replica_count() const { return 0; }
   virtual int64 num_cores() const { return 0; }
+  virtual bool use_spmd_partitioning() const { return false; }
 
   // Optional allocator that may be used for allocating temp space on the device
   // during compilation.
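compiler.h now exposes `use_spmd_partitioning()` on `AotCompilationOptions` (defaulting to false), and compile_only_service.cc forwards it into the `ExecutionOptions` it builds, so an ahead-of-time client opts in by overriding the accessor. A minimal sketch with a hypothetical subclass name; the remaining pure-virtual members still need real implementations:

```c++
#include "tensorflow/compiler/xla/service/compiler.h"

// Hypothetical options subclass that opts AOT compilation into SPMD
// partitioning; everything except the new accessor is elided.
class SpmdAotCompilationOptions : public xla::AotCompilationOptions {
 public:
  bool use_spmd_partitioning() const override { return true; }
  // ... the class's required pure-virtual members go here ...
};
```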
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index e21ca01..05364a4 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -109,24 +109,6 @@
     const HloInstruction* hlo,
     const HloToElementGeneratorMap& operand_to_generator) {
   switch (hlo->opcode()) {
-    case HloOpcode::kMap:
-      return [this, hlo, &operand_to_generator](
-                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        std::vector<llvm::Value*> operands;
-        for (int i = 0; i < hlo->operand_count(); i++) {
-          TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
-                              operand_to_generator.at(hlo->operand(i))(index));
-          operands.push_back(operand_value);
-        }
-        return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
-                                             operands, llvm_ir::IrName(hlo));
-      };
-    case HloOpcode::kReduceWindow:
-      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
-        return ir_emitter_->EmitElementalReduceWindow(
-            Cast<HloReduceWindowInstruction>(hlo),
-            operand_to_generator.at(hlo->operand(0)), index);
-      };
     case HloOpcode::kConvolution:
       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
         return ir_emitter_->EmitElementalConvolution(
@@ -134,22 +116,6 @@
             operand_to_generator.at(hlo->operand(0)),
             operand_to_generator.at(hlo->operand(1)), index);
       };
-    case HloOpcode::kReduce:
-      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
-        auto reduce_instr = Cast<HloReduceInstruction>(hlo);
-        std::vector<llvm_ir::ElementGenerator> input_generators;
-        for (const HloInstruction* instr : reduce_instr->inputs()) {
-          input_generators.push_back(operand_to_generator.at(instr));
-        }
-
-        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
-        for (const HloInstruction* instr : reduce_instr->init_values()) {
-          initial_value_generators.push_back(operand_to_generator.at(instr));
-        }
-        return ir_emitter_->EmitElementalReduce(
-            reduce_instr, std::move(input_generators),
-            std::move(initial_value_generators), index);
-      };
     default:
       return ElementalIrEmitter::MakeElementGenerator(hlo,
                                                       operand_to_generator);
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index e3fba93..5c9f667 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -44,6 +44,12 @@
   StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                   llvm::Value* value) override;
 
+  StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
+      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+      absl::string_view name) override {
+    return ir_emitter_->EmitThreadLocalCall(callee, parameters, name);
+  }
+
   IrEmitter* ir_emitter_;
 };
 
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index c19fa77..2b715bf 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -695,101 +695,6 @@
   return Status::OK();
 }
 
-llvm::Value* IrEmitter::EmitElementalMap(
-    const HloMapInstruction& map_instr,
-    absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
-  return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(),
-                                            elemental_operands, name);
-}
-
-StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
-    const HloReduceWindowInstruction* reduce_window,
-    const llvm_ir::ElementGenerator& input_generator,
-    const llvm_ir::IrArray::Index& index) {
-  const HloInstruction* operand = reduce_window->operand(0);
-  const Window& window = reduce_window->window();
-
-  // We fold inputs into the accumulator and initialize it to
-  // the initial value on the reduce_window.
-  PrimitiveType operand_element_type = operand->shape().element_type();
-  llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
-      llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
-      "reduce_window_accumulator_address", &b_,
-      MinimumAlignmentForPrimitiveType(operand_element_type));
-  Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
-        accumulator_address);
-
-  llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
-  std::vector<int64> window_size;
-  for (const auto& dim : window.dimensions()) {
-    window_size.push_back(dim.size());
-  }
-  const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
-      ShapeUtil::MakeShape(operand_element_type, window_size), "window");
-  CHECK_EQ(window_index.size(), index.size());
-
-  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
-  std::vector<llvm::Value*> input_multi_index(index.size());
-  llvm::Value* in_bounds_condition = nullptr;
-  for (size_t i = 0; i < index.size(); ++i) {
-    llvm::Value* strided_index =
-        NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
-    input_multi_index[i] = NSWSub(
-        NSWAdd(strided_index,
-               NSWMul(window_index[i],
-                      b_.getInt64(window.dimensions(i).window_dilation()))),
-        b_.getInt64(window.dimensions(i).padding_low()));
-
-    // We need to verify that we are not in the dilated base area.
-    llvm::Value* dilation_condition =
-        ICmpEQ(SRem(input_multi_index[i],
-                    b_.getInt64(window.dimensions(i).base_dilation())),
-               b_.getInt64(0));
-    if (in_bounds_condition == nullptr) {
-      in_bounds_condition = dilation_condition;
-    } else {
-      in_bounds_condition = And(in_bounds_condition, dilation_condition);
-    }
-
-    // Apply base dilation to the index.
-    input_multi_index[i] =
-        SDiv(input_multi_index[i],
-             b_.getInt64(window.dimensions(i).base_dilation()));
-
-    // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we
-    // are in the padding so that we can skip the computation. That is
-    // equivalent to input_multi_index[i] < bound as an *unsigned* comparison,
-    // since a negative value will wrap to a large positive value.
-    llvm::Value* index_condition =
-        ICmpULT(input_multi_index[i],
-                b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
-    if (in_bounds_condition == nullptr) {
-      in_bounds_condition = index_condition;
-    } else {
-      in_bounds_condition = And(in_bounds_condition, index_condition);
-    }
-  }
-  CHECK(in_bounds_condition != nullptr);
-
-  llvm_ir::LlvmIfData if_data =
-      llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
-  SetToFirstInsertPoint(if_data.true_block, &b_);
-
-  // We are not in the padding, so carry out the computation.
-  llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(),
-                                      b_.getInt64Ty());
-  TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
-                      input_generator(input_index));
-  llvm::Value* result = EmitScalarReturningThreadLocalCall(
-      *reduce_window->to_apply(), {Load(accumulator_address), input_value},
-      "reducer_function");
-  Store(result, accumulator_address);
-
-  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-  return Load(accumulator_address);
-}
-
 Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
   // Pseudo code for reduce window:
   //
@@ -2099,108 +2004,6 @@
   return true;
 }
 
-StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
-    const HloReduceInstruction* reduce,
-    std::vector<llvm_ir::ElementGenerator> input_generators,
-    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
-    const llvm_ir::IrArray::Index& index) {
-  const Shape& out_shape = reduce->shape();
-  bool is_variadic = !out_shape.IsArray();
-  int accumulators_count = 1;
-  if (is_variadic) {
-    CHECK(out_shape.IsTuple());
-    accumulators_count = out_shape.tuple_shapes_size();
-  }
-
-  absl::Span<const int64> reduced_dimensions(reduce->dimensions());
-
-  std::vector<llvm::Value*> accumulator_addrs;
-  std::vector<llvm::Type*> accumulator_types;
-  for (int i = 0; i < accumulators_count; i++) {
-    const Shape& element_shape =
-        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
-    PrimitiveType accumulator_type = element_shape.element_type();
-    llvm::Type* accumulator_llvm_type =
-        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
-    accumulator_types.push_back(accumulator_llvm_type);
-
-    // Initialize an accumulator with init_value.
-    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
-        accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_,
-        MinimumAlignmentForPrimitiveType(accumulator_type));
-    TF_ASSIGN_OR_RETURN(
-        llvm::Value* const init_value,
-        initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType())));
-    Store(init_value, accumulator_addr);
-    accumulator_addrs.push_back(accumulator_addr);
-  }
-
-  // The enclosing loops go over all the target elements. Now we have to compute
-  // the actual target element. For this, we build a new loop nest to iterate
-  // over all the reduction dimensions in the argument.
-  // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
-  // are placed for each dimension in dimensions, and all the rest are nullptrs.
-  llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
-  const HloInstruction* arg = reduce->operand(0);
-  std::vector<llvm::Value*> input_multi_index =
-      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
-                                         "reduction_dim");
-
-  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
-  // Build a full index for the input argument, using input_multi_index as the
-  // base. In input_multi_index only the reduction dimensions are filled in. We
-  // fill in the rest of the dimensions with induction Value*s taken from
-  // 'index' which iterates over the target array.  See the high-level
-  // description in the XLA documentation for details.
-  llvm_ir::IrArray::Index::const_iterator it = index.begin();
-
-  for (auto& i : input_multi_index) {
-    if (i == nullptr) {
-      i = *it++;
-    }
-  }
-  CHECK(index.end() == it);
-  llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
-                                      b_.getInt64Ty());
-
-  std::vector<llvm::Value*> reduction_operands;
-  for (llvm::Value* accum : accumulator_addrs) {
-    llvm::Value* accum_value = Load(accum);
-    reduction_operands.push_back(accum_value);
-  }
-
-  for (int i = 0; i < accumulators_count; i++) {
-    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
-                        input_generators[i](input_index));
-    reduction_operands.push_back(input_element);
-  }
-
-  std::vector<llvm::Value*> results = EmitThreadLocalCall(
-      *reduce->to_apply(), reduction_operands, "reduce_function");
-
-  CHECK(results.size() == accumulators_count);
-  for (int i = 0; i < accumulators_count; i++) {
-    Store(results[i], accumulator_addrs[i]);
-  }
-  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-
-  if (is_variadic) {
-    // Emit a structure, as that what the LoopEmitter expects.
-    llvm::Value* returned_structure = llvm::UndefValue::get(
-        llvm::StructType::get(b_.getContext(), accumulator_types));
-    for (int i = 0; i < accumulators_count; i++) {
-      llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
-      returned_structure =
-          b_.CreateInsertValue(returned_structure, accumulator_value, i);
-    }
-    return returned_structure;
-  } else {
-    CHECK_EQ(accumulator_addrs.size(), 1);
-    return Load(accumulator_addrs[0]);
-  }
-}
-
 Status IrEmitter::HandleReduce(HloInstruction* reduce) {
   auto arg = reduce->mutable_operand(0);
   auto init_value = reduce->mutable_operand(1);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index cc5aa3f..24524c6 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -58,6 +58,8 @@
 // functions.
 class IrEmitter : public DfsHloVisitorWithDefault,
                   public IrBuilderMixin<IrEmitter> {
+  friend class CpuElementalIrEmitter;
+
  public:
   using GeneratorForOperandIrArrays =
       std::function<std::vector<llvm_ir::IrArray>()>;
@@ -113,28 +115,12 @@
   // Emit an LLVM global variable for every constant buffer allocation.
   Status EmitConstantGlobals();
 
-  // Emit code to map one element according to `map_instr`.
-  llvm::Value* EmitElementalMap(
-      const HloMapInstruction& map_instr,
-      absl::Span<llvm::Value* const> elemental_operands,
-      absl::string_view name);
-  // Emit code to emit the element at `index` for a reduce window instruction.
-  StatusOr<llvm::Value*> EmitElementalReduceWindow(
-      const HloReduceWindowInstruction* reduce_window,
-      const llvm_ir::ElementGenerator& input_generator,
-      const llvm_ir::IrArray::Index& index);
   // Emit code to emit the element at `index` for a convolution instruction.
   StatusOr<llvm::Value*> EmitElementalConvolution(
       const HloConvolutionInstruction* convolution,
       const llvm_ir::ElementGenerator& input_generator,
       const llvm_ir::ElementGenerator& kernel_generator,
       const llvm_ir::IrArray::Index& index);
-  // Emit code to emit the element at `index` for a reduce instruction.
-  StatusOr<llvm::Value*> EmitElementalReduce(
-      const HloReduceInstruction* reduce,
-      std::vector<llvm_ir::ElementGenerator> input_generators,
-      std::vector<llvm_ir::ElementGenerator> initial_value_generator,
-      const llvm_ir::IrArray::Index& index);
 
  protected:
   //
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 30300b8..8cb660d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -2422,6 +2422,43 @@
                  -> StatusOr<llvm::Value*> {
         return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
       };
+    case HloOpcode::kMap:
+      return [this, hlo, &operand_to_generator](
+                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+        std::vector<llvm::Value*> operands;
+        for (int i = 0; i < hlo->operand_count(); i++) {
+          TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+                              operand_to_generator.at(hlo->operand(i))(index));
+          operands.push_back(operand_value);
+        }
+        return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
+      };
+    case HloOpcode::kReduceWindow:
+      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+        return EmitElementalReduceWindow(
+            Cast<HloReduceWindowInstruction>(hlo),
+            operand_to_generator.at(hlo->operand(0)),
+            operand_to_generator.at(hlo->operand(1)), index);
+      };
+    case HloOpcode::kReduce:
+      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+        auto reduce_instr = Cast<HloReduceInstruction>(hlo);
+        std::vector<llvm_ir::ElementGenerator> input_generators;
+        for (const HloInstruction* instr : reduce_instr->inputs()) {
+          input_generators.push_back(operand_to_generator.at(instr));
+        }
+
+        std::vector<llvm_ir::ElementGenerator> initial_value_generators;
+        for (const HloInstruction* instr : reduce_instr->init_values()) {
+          initial_value_generators.push_back(operand_to_generator.at(instr));
+        }
+        return EmitElementalReduce(reduce_instr, std::move(input_generators),
+                                   std::move(initial_value_generators), index);
+      };
     default:
       return [hlo](const IrArray::Index& index) {
         return Unimplemented("Unhandled opcode for elemental IR emission: %s",
@@ -2451,4 +2488,215 @@
   return complex;
 }
 
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
+    const HloMapInstruction* map_instr,
+    absl::Span<llvm::Value* const> elemental_operands) {
+  TF_ASSIGN_OR_RETURN(
+      std::vector<llvm::Value*> values,
+      EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
+                          llvm_ir::IrName(map_instr)));
+  CHECK_EQ(values.size(), 1);
+  return values[0];
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
+    const HloReduceWindowInstruction* reduce_window,
+    const llvm_ir::ElementGenerator& input_generator,
+    const llvm_ir::ElementGenerator& initial_value_generator,
+    const llvm_ir::IrArray::Index& index) {
+  // Pseudocode:
+  // for each output index O
+  //   value = init_value
+  //   for each window index W
+  //     for each dimension i from 0 to rank - 1
+  //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
+  //     if I in bounds of input
+  //       value = function(value, input[I])
+  //   output[O] = value
+  const HloInstruction* operand = reduce_window->operand(0);
+  const Window& window = reduce_window->window();
+
+  PrimitiveType operand_element_type = operand->shape().element_type();
+  llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
+      llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
+      "reduce_window_accum_ptr", b_);
+  {
+    TF_ASSIGN_OR_RETURN(
+        llvm::Value* const init_value,
+        initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
+    Store(init_value, accum_ptr);
+  }
+
+  llvm::Type* index_type = index.GetType();
+  auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+    return index.GetConstantWithIndexType(c);
+  };
+
+  llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
+  std::vector<int64> window_size;
+  for (const auto& dim : window.dimensions()) {
+    window_size.push_back(dim.size());
+  }
+  const IrArray::Index window_index = loops.AddLoopsForShape(
+      ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+  CHECK_EQ(window_index.size(), index.size());
+
+  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
+
+  std::vector<llvm::Value*> input_multi_index(index.size());
+  llvm::Value* in_bounds = b_->getInt1(true);
+  for (size_t i = 0; i < index.size(); ++i) {
+    llvm::Value* strided_index =
+        NSWMul(index[i], index_typed_const(window.dimensions(i).stride()));
+    input_multi_index[i] = NSWSub(
+        NSWAdd(
+            strided_index,
+            NSWMul(window_index[i],
+                   index_typed_const(window.dimensions(i).window_dilation()))),
+        index_typed_const(window.dimensions(i).padding_low()));
+
+    // We need to verify that we are not in the dilated base area.
+    llvm::Value* dilation_condition =
+        ICmpEQ(SRem(input_multi_index[i],
+                    index_typed_const(window.dimensions(i).base_dilation())),
+               index_typed_const(0));
+    in_bounds = And(in_bounds, dilation_condition);
+
+    // Apply base dilation to the index.
+    input_multi_index[i] =
+        SDiv(input_multi_index[i],
+             index_typed_const(window.dimensions(i).base_dilation()));
+
+    // We must check whether 0 <= input_multi_index[i] < bound, as
+    // otherwise we are in the padding and can skip the computation. This
+    // comparison is equivalent to the unsigned comparison
+    // input_multi_index[i] < bound, as a negative value wraps to a large
+    // positive value.
+    in_bounds = And(in_bounds,
+                    ICmpULT(input_multi_index[i],
+                            index_typed_const(operand->shape().dimensions(i))));
+  }
+
+  llvm_ir::LlvmIfData if_data =
+      llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
+  SetToFirstInsertPoint(if_data.true_block, b_);
+
+  // We are not in the padding, so carry out the computation.
+  IrArray::Index input_index(input_multi_index, operand->shape(), index_type);
+  TF_ASSIGN_OR_RETURN(llvm::Value * input_value, input_generator(input_index));
+  TF_ASSIGN_OR_RETURN(
+      std::vector<llvm::Value*> accum_values,
+      EmitThreadLocalCall(*reduce_window->to_apply(),
+                          {Load(accum_ptr), input_value}, "reducer_function"));
+  CHECK_EQ(accum_values.size(), 1);
+  Store(accum_values[0], accum_ptr);
+
+  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
+  return Load(accum_ptr);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
+    const HloReduceInstruction* reduce,
+    std::vector<llvm_ir::ElementGenerator> input_generators,
+    std::vector<llvm_ir::ElementGenerator> initial_value_generators,
+    const llvm_ir::IrArray::Index& index) {
+  const Shape& out_shape = reduce->shape();
+  bool is_variadic = !out_shape.IsArray();
+  int accumulators_count = 1;
+  if (is_variadic) {
+    CHECK(out_shape.IsTuple());
+    accumulators_count = out_shape.tuple_shapes_size();
+  }
+
+  absl::Span<const int64> reduced_dimensions(reduce->dimensions());
+
+  std::vector<llvm::Value*> accumulator_addrs;
+  std::vector<llvm::Type*> accumulator_types;
+  llvm::Type* index_type = index.GetType();
+  for (int i = 0; i < accumulators_count; i++) {
+    const Shape& element_shape =
+        is_variadic ? out_shape.tuple_shapes(i) : out_shape;
+    PrimitiveType accumulator_type = element_shape.element_type();
+    llvm::Type* accumulator_llvm_type =
+        llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
+    accumulator_types.push_back(accumulator_llvm_type);
+
+    // Initialize an accumulator with init_value.
+    llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+        accumulator_llvm_type, "accumulator_" + std::to_string(i), b());
+    TF_ASSIGN_OR_RETURN(
+        llvm::Value* const init_value,
+        initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
+    Store(init_value, accumulator_addr);
+    accumulator_addrs.push_back(accumulator_addr);
+  }
+
+  // The enclosing loops go over all the target elements. Now we have to compute
+  // the actual target element. For this, we build a new loop nest to iterate
+  // over all the reduction dimensions in the argument.
+  // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
+  // are placed for each dimension in dimensions, and all the rest are nullptrs.
+  llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
+  const HloInstruction* arg = reduce->operand(0);
+  std::vector<llvm::Value*> input_multi_index =
+      loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
+                                         "reduction_dim");
+
+  SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
+
+  // Build a full index for the input argument, using input_multi_index as the
+  // base. In input_multi_index only the reduction dimensions are filled in. We
+  // fill in the rest of the dimensions with induction Value*s taken from
+  // 'index' which iterates over the target array.  See the high-level
+  // description in the XLA documentation for details.
+  auto it = index.begin();
+
+  for (auto& i : input_multi_index) {
+    if (i == nullptr) {
+      i = *it++;
+    }
+  }
+  CHECK(index.end() == it);
+  llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
+                                      index_type);
+
+  std::vector<llvm::Value*> reduction_operands;
+  for (llvm::Value* accum : accumulator_addrs) {
+    llvm::Value* accum_value = Load(accum);
+    reduction_operands.push_back(accum_value);
+  }
+
+  for (int i = 0; i < accumulators_count; i++) {
+    TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
+                        input_generators[i](input_index));
+    reduction_operands.push_back(input_element);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<llvm::Value*> results,
+      EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
+                          "reduce_function"));
+
+  CHECK_EQ(results.size(), accumulators_count);
+  for (int i = 0; i < accumulators_count; i++) {
+    Store(results[i], accumulator_addrs[i]);
+  }
+  SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
+
+  if (is_variadic) {
+    // Emit a structure, as that is what the LoopEmitter expects.
+    llvm::Value* returned_structure = llvm::UndefValue::get(
+        llvm::StructType::get(b()->getContext(), accumulator_types));
+    for (int i = 0; i < accumulators_count; i++) {
+      llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
+      returned_structure =
+          b()->CreateInsertValue(returned_structure, accumulator_value, i);
+    }
+    return returned_structure;
+  } else {
+    CHECK_EQ(accumulator_addrs.size(), 1);
+    return Load(accumulator_addrs[0]);
+  }
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 94e8f1d..06a9d7b 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -17,12 +17,17 @@
 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
 
 #include <unordered_map>
+#include <vector>
 
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Value.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -220,6 +225,26 @@
       const HloToElementGeneratorMap& operand_to_generator,
       const llvm_ir::IrArray::Index& dot_result_index);
 
+  virtual StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
+      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+      absl::string_view name) = 0;
+
+  StatusOr<llvm::Value*> EmitElementalMap(
+      const HloMapInstruction* map_instr,
+      absl::Span<llvm::Value* const> elemental_operands);
+
+  StatusOr<llvm::Value*> EmitElementalReduceWindow(
+      const HloReduceWindowInstruction* reduce_window,
+      const llvm_ir::ElementGenerator& input_generator,
+      const llvm_ir::ElementGenerator& initial_value_generator,
+      const llvm_ir::IrArray::Index& index);
+
+  StatusOr<llvm::Value*> EmitElementalReduce(
+      const HloReduceInstruction* reduce,
+      std::vector<llvm_ir::ElementGenerator> input_generators,
+      std::vector<llvm_ir::ElementGenerator> initial_value_generators,
+      const llvm_ir::IrArray::Index& index);
+
   llvm::IRBuilder<>* const b_;
 
   llvm::Module* module_;
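With kMap/kReduceWindow/kReduce emission hoisted out of the CPU and GPU emitters into the shared ElementalIrEmitter, the only backend-specific piece left is the pure-virtual `EmitThreadLocalCall` (the CPU emitter forwards to `IrEmitter::EmitThreadLocalCall`, hence the new friend declaration; the GPU emitter adapts `compute_nested_`). A sketch of the single hook a hypothetical new backend would supply, with constructor wiring elided:

```c++
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/util.h"

class MyBackendElementalIrEmitter : public xla::ElementalIrEmitter {
 protected:
  xla::StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
      const xla::HloComputation& callee,
      absl::Span<llvm::Value* const> parameters,
      absl::string_view name) override {
    // Backend-specific: emit (or reuse) a function for `callee` and call it
    // once per element with `parameters`. Variadic computations return one
    // llvm::Value* per result.
    return xla::Unimplemented("thread-local calls not implemented yet");
  }
};
```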
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 61bc412..bff8734 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -17,15 +17,15 @@
     "tf_cuda_library",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm",
+    "if_rocm_is_configured",
+)
 load(
     "//tensorflow/core/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
-load(
-    "@local_config_rocm//rocm:build_defs.bzl",
-    "if_rocm_is_configured",
-)
 load("//tensorflow:tensorflow.bzl", "if_nccl")
 
 package(
@@ -86,6 +86,7 @@
     name = "gpu_types",
     hdrs = ["gpu_types.h"],
     deps = [
+        "//tensorflow/compiler/xla:types",
         "@com_google_absl//absl/types:variant",
     ],
 )
@@ -405,6 +406,7 @@
     deps = [
         ":buffer_allocations",
         ":gpu_executable_run_options",
+        ":gpu_types",
         ":hlo_execution_profiler",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla/service:hlo",
@@ -684,7 +686,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_pass",
-        "//tensorflow/core:autotuning_proto_cc",
+        "//tensorflow/core/protobuf:autotuning_proto_cc",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core/util/proto:proto_utils",
@@ -720,7 +722,7 @@
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:hlo_pass",
-        "//tensorflow/core:autotuning_proto_cc",
+        "//tensorflow/core/protobuf:autotuning_proto_cc",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:stream_executor_no_cuda",
@@ -1674,7 +1676,7 @@
     protodeps = [
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo_proto",
-        "//tensorflow/core:autotuning_proto",
+        "//tensorflow/core/protobuf:autotuning_proto",
     ],
 )
 
@@ -1685,8 +1687,8 @@
     deps = [
         ":gpu_autotuning_proto_cc",
         "//tensorflow/compiler/xla:debug_options_flags",
-        "//tensorflow/core:autotuning_proto_cc",
         "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/protobuf:autotuning_proto_cc",
         "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
index 974db02..485aff0 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -104,11 +104,9 @@
   return isa_version;
 }
 
-StatusOr<std::pair<std::string, std::vector<uint8>>>
-AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
-                                    llvm::Module* llvm_module,
-                                    GpuVersion gpu_version,
-                                    se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> AMDGPUCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   if (rocdl_dir_.empty()) {
     // Compute rocdl_dir_ just once and cache it in this member.
     rocdl_dir_ = GetROCDLDir(module->config());
@@ -129,7 +127,7 @@
     user_post_optimization_hook_(*llvm_module);
   }
 
-  return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
+  return GpuTargetBinary{"", std::move(hsaco)};
 }
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
index acc5e02..9033585 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
@@ -39,7 +39,7 @@
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
 
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index e31f459..5e7d89c 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -50,7 +50,7 @@
   }
 }
 
-Status ConditionalThunk::Initialize(const GpuExecutable& executable,
+Status ConditionalThunk::Initialize(const GpuTargetBinary& target_binary,
                                     se::StreamExecutor* executor) {
   if (branch_index_is_bool_) {
     TF_RET_CHECK(branch_thunks_.size() == 2);
@@ -58,7 +58,7 @@
     TF_RET_CHECK(!branch_thunks_.empty());
   }
   for (auto& branch_thunk : branch_thunks_) {
-    TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(branch_thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
index 404e213..ba69e1a 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -52,7 +52,7 @@
   ConditionalThunk& operator=(const ConditionalThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index c6df786..1be0b1b 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -305,168 +305,5 @@
   return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
 }
 
-llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
-    const HloInstruction* hlo,
-    const HloToElementGeneratorMap& operand_to_generator) {
-  switch (hlo->opcode()) {
-    case HloOpcode::kMap:
-      return [=, &operand_to_generator](
-                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        TF_RET_CHECK(!hlo->operands().empty())
-            << "Zero operand map not implemented in GPU backend.";
-        TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0);
-        std::vector<llvm::Value*> operand_elements;
-        for (HloInstruction* operand : hlo->operands()) {
-          TF_ASSIGN_OR_RETURN(llvm::Value * value,
-                              operand_to_generator.at(operand)(index));
-          operand_elements.push_back(value);
-        }
-        return compute_nested_(*hlo->to_apply(), operand_elements);
-      };
-    case HloOpcode::kReduceWindow:
-      // Pseudocode:
-      // for each index I in output
-      //   value = init_value
-      //   for each index W in window
-      //     for each dimension i from 0 to rank - 1
-      //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
-      //     if I in bounds of input
-      //       value = function(value, input[I])
-      //     output[O] = value
-      return [=, &operand_to_generator](
-                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-        const HloInstruction* operand = hlo->operand(0);
-        const Window& window = hlo->window();
-
-        PrimitiveType operand_element_type = operand->shape().element_type();
-        llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
-            llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
-            "reduce_window_accum_ptr", b_);
-        {
-          TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
-                              operand_to_generator.at(hlo->operand(1))(
-                                  IrArray::Index(index.GetType())));
-          Store(init_value, accum_ptr);
-        }
-
-        llvm::Type* index_type = index.GetType();
-        auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
-          return index.GetConstantWithIndexType(c);
-        };
-
-        llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
-        std::vector<int64> window_size;
-        for (const auto& dim : window.dimensions()) {
-          window_size.push_back(dim.size());
-        }
-        const IrArray::Index window_index = loops.AddLoopsForShape(
-            ShapeUtil::MakeShape(operand_element_type, window_size), "window");
-        CHECK_EQ(window_index.size(), index.size());
-
-        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
-
-        std::vector<llvm::Value*> input_multi_index(index.size());
-        llvm::Value* in_bounds = b_->getInt1(true);
-        for (size_t i = 0; i < index.size(); ++i) {
-          llvm::Value* stridden_index = NSWMul(
-              index[i], index_typed_const(window.dimensions(i).stride()));
-          input_multi_index[i] = NSWSub(
-              NSWAdd(stridden_index,
-                     NSWMul(window_index[i],
-                            index_typed_const(
-                                window.dimensions(i).window_dilation()))),
-              index_typed_const(window.dimensions(i).padding_low()));
-
-          // We need to verify that we are not in the dilated base area.
-          llvm::Value* dilation_condition = ICmpEQ(
-              SRem(input_multi_index[i],
-                   index_typed_const(window.dimensions(i).base_dilation())),
-              index_typed_const(0));
-          in_bounds = And(in_bounds, dilation_condition);
-
-          // Apply base dilation to the index.
-          input_multi_index[i] =
-              SDiv(input_multi_index[i],
-                   index_typed_const(window.dimensions(i).base_dilation()));
-
-          // We must check whether 0 <= input_multi_index[i] < bound, as
-          // otherwise we are in the pad and so can skip the computation. This
-          // comparison is equivalent to the unsigned comparison
-          // input_multi_index[i] < bound, as a negative value wraps to a large
-          // positive value.
-          in_bounds =
-              And(in_bounds,
-                  ICmpULT(input_multi_index[i],
-                          index_typed_const(operand->shape().dimensions(i))));
-        }
-
-        llvm_ir::LlvmIfData if_data =
-            llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
-        SetToFirstInsertPoint(if_data.true_block, b_);
-
-        // We are not in pad, so do the computation.
-        IrArray::Index input_index(input_multi_index, operand->shape(),
-                                   index_type);
-        TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
-                            operand_to_generator.at(operand)(input_index));
-        TF_ASSIGN_OR_RETURN(
-            llvm::Value * accum_value,
-            compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value}));
-        Store(accum_value, accum_ptr);
-
-        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
-        return Load(accum_ptr);
-      };
-    case HloOpcode::kReduce:
-      // TODO(b/118332391): This should be supported.
-      CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce";
-      return [=, &operand_to_generator](
-                 const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
-        const HloInstruction* operand = hlo->operand(0);
-        llvm::Value* accum_ptr =
-            b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
-                hlo->shape().element_type(), module_));
-        llvm::Type* index_type = output_index.GetType();
-        TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
-                            operand_to_generator.at(hlo->operand(1))(
-                                IrArray::Index(index_type)));
-        b()->CreateStore(init_value, accum_ptr);
-
-        llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
-        std::vector<llvm::Value*> input_multi_index =
-            loops.AddLoopsForShapeOnDimensions(
-                operand->shape(), hlo->dimensions(), "reduction_dim");
-        if (!ShapeUtil::IsScalar(hlo->shape())) {
-          // Here only input_multi_index[hlo->dimensions()] are non-null, so we
-          // must set the rest.
-          size_t j = 0;
-          for (auto& i : input_multi_index) {
-            if (i == nullptr) {
-              i = output_index[j++];
-            }
-          }
-          CHECK_EQ(output_index.size(), j);
-        }
-        llvm_ir::IrArray::Index input_index(
-            input_multi_index, hlo->operand(0)->shape(), index_type);
-
-        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
-        TF_ASSIGN_OR_RETURN(
-            llvm::Value * input_value,
-            operand_to_generator.at(hlo->operand(0))(input_index));
-        TF_ASSIGN_OR_RETURN(
-            llvm::Value * accum_value,
-            compute_nested_(*hlo->to_apply(),
-                            {b()->CreateLoad(accum_ptr), input_value}));
-        b()->CreateStore(accum_value, accum_ptr);
-        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
-        return b()->CreateLoad(accum_ptr);
-      };
-    default:
-      return ElementalIrEmitter::MakeElementGenerator(hlo,
-                                                      operand_to_generator);
-  }
-}
-
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index c8a58a2..3c4e9f7 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -47,10 +47,6 @@
                         llvm::Module* module, llvm::IRBuilder<>* b,
                         NestedComputer compute_nested);
 
-  llvm_ir::ElementGenerator MakeElementGenerator(
-      const HloInstruction* hlo,
-      const HloToElementGeneratorMap& operand_to_generator) override;
-
  protected:
   StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
                                            llvm::Value* lhs_value,
@@ -92,6 +88,17 @@
   StatusOr<llvm::Value*> EmitComplexAbs(PrimitiveType prim_type,
                                         llvm::Value* value) override;
 
+  StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
+      const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+      absl::string_view) override {
+    // TODO(b/118332391): Support variadic return values.
+    auto result = compute_nested_(callee, parameters);
+    if (!result.ok()) {
+      return result.status();
+    }
+    return std::vector<llvm::Value*>{result.ValueOrDie()};
+  }
+
   llvm::Value* EmitThreadId() override;
 
  private:
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index 0a97f66..aacc9de 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -39,9 +39,9 @@
   body_thunk_sequence_->ComputeAnnotations();
 }
 
-Status ForThunk::Initialize(const GpuExecutable& executable,
+Status ForThunk::Initialize(const GpuTargetBinary& target_binary,
                             se::StreamExecutor* executor) {
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h
index 57402f7..57657b6 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -38,7 +38,7 @@
   ForThunk& operator=(const ForThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5f6dfd7..533ff52 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -565,8 +565,7 @@
 
   GpuVersion gpu_version = GetGpuVersion(stream_exec);
 
-  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
-  TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
+  TF_ASSIGN_OR_RETURN(GpuTargetBinary backend_result,
                       CompileTargetBinary(module.get(), &llvm_module,
                                           gpu_version, stream_exec));
 
@@ -578,6 +577,11 @@
                             thunk_schedule->ToString());
   }
 
+  std::vector<Thunk*> thunks;
+  for (Thunk* thunk : thunk_schedule->TotalOrder()) {
+    thunks.push_back(thunk);
+  }
+
   std::unique_ptr<HloProfileIndexMap> profile_index_map;
   std::unique_ptr<HloProfilePrinterData> profile_printer;
 
@@ -597,14 +601,19 @@
   }
 
   auto* gpu_executable = new GpuExecutable(
-      backend_result.first, backend_result.second, gpu_version,
-      std::move(thunk_schedule), std::move(module),
-      std::move(buffer_assignment), std::move(profile_printer),
-      std::move(profile_index_map));
+      std::move(backend_result), gpu_version, std::move(thunk_schedule),
+      std::move(module), std::move(buffer_assignment),
+      std::move(profile_printer), std::move(profile_index_map));
   if (embed_ir_in_executable) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
   }
+
+  for (Thunk* thunk : thunks) {
+    TF_RETURN_IF_ERROR(
+        thunk->Initialize(gpu_executable->target_binary(), stream_exec));
+  }
+
   return std::unique_ptr<Executable>(gpu_executable);
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index b52af53..deb5d78 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -74,10 +74,9 @@
 
   virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;
 
-  virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
-  CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
-                      GpuVersion gpu_version,
-                      se::StreamExecutor* stream_exec) = 0;
+  virtual StatusOr<GpuTargetBinary> CompileTargetBinary(
+      const HloModule* hlo_module, llvm::Module* llvm_module,
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) = 0;
 
   Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 2df6b50..ebd3630 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -52,16 +52,15 @@
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
 GpuExecutable::GpuExecutable(
-    const string& text, const std::vector<uint8>& binary,
-    GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
+    GpuTargetBinary target_binary, GpuVersion gpu_version,
+    std::unique_ptr<const ThunkSchedule> thunk_schedule,
     std::shared_ptr<HloModule> hlo_module,
     std::shared_ptr<const BufferAssignment> assignment,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
     : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                  std::move(hlo_profile_index_map)),
-      text_(text),
-      binary_(binary),
+      target_binary_(std::move(target_binary)),
       gpu_version_(gpu_version),
       thunk_schedule_(std::move(thunk_schedule)),
       assignment_(std::move(assignment)) {
@@ -176,7 +175,6 @@
     // module, we won't get any data, but that's probably an OK trade-off.
     ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
 
-    TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
     int32 stream_no =
         thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction());
     se::Stream* stream =
@@ -469,7 +467,7 @@
 int64 GpuExecutable::SizeOfGeneratedCodeInBytes() {
   // Non-empty PTX but empty cubin: compilation must have failed, return
   // "unknown".
-  if (binary().empty() && !text_.empty()) {
+  if (binary().empty() && !text().empty()) {
     return -1;
   }
   return binary().size();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 045a36c..29441c6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -52,8 +52,7 @@
   // We need to share ownership of hlo_module and assignment with profiler to
   // safely keep a reference to these objects during tracing period, thus they
   // are passed as shared pointers.
-  GpuExecutable(const string& text, const std::vector<uint8>& binary,
-                GpuVersion gpu_version,
+  GpuExecutable(GpuTargetBinary target_binary, GpuVersion gpu_version,
                 std::unique_ptr<const ThunkSchedule> thunk_schedule,
                 std::shared_ptr<HloModule> hlo_module,
                 std::shared_ptr<const BufferAssignment> assignment,
@@ -73,12 +72,14 @@
 
   // Returns the compiled code for the computation. The compiled code is PTX in
   // Cuda and unused empty string in ROCm.
-  const string& text() const { return text_; }
+  const string& text() const { return target_binary_.text; }
 
   // Returns the binary stored in this GpuExecutable. The binary is cubin in
   // Cuda, and HSA code object in ROCm. It may be empty, in which case
   // compilation is left up to the GPU driver.
-  const std::vector<uint8>& binary() const { return binary_; }
+  const std::vector<uint8>& binary() const { return target_binary_.binary; }
+
+  const GpuTargetBinary& target_binary() const { return target_binary_; }
 
   // ExecuteAsyncOnStream will fail if the compute capability of the stream
   // doesn't match the compute capability passed to this object's constructor.
@@ -131,14 +132,7 @@
   // This string should be modified only before ExecuteOnStream.
   string ir_module_string_;
 
-  // The compiled code for the computation.
-  const string text_;
-
-  // The GPU machine code for the computation, targeting GPUs at
-  // compute_capability_.
-  //
-  // May be empty, in which case we leave compilation up to the GPU driver.
-  const std::vector<uint8> binary_;
+  const GpuTargetBinary target_binary_;
 
   // The GPU version for compute compatibility check.
   GpuVersion gpu_version_;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_types.h b/tensorflow/compiler/xla/service/gpu/gpu_types.h
index 1c51040..5c8b809 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_types.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_types.h
@@ -16,7 +16,11 @@
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
 
+#include <string>
+#include <vector>
+
 #include "absl/types/variant.h"
+#include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
 namespace gpu {
@@ -25,6 +29,19 @@
 // it comprises a pair of integers denoting major and minor version.
 // On ROCm platform, it comprises one integer for AMD GCN ISA version.
 using GpuVersion = absl::variant<std::pair<int, int>, int>;
+
+// A struct that carries the compiled results produced by the GPU assembler.
+struct GpuTargetBinary {
+  GpuTargetBinary(const GpuTargetBinary& other) = delete;
+  GpuTargetBinary(GpuTargetBinary&& other) = default;
+
+  // The text format of the compiled result, e.g. PTX.
+  std::string text;
+
+  // The actual compiled binary.
+  std::vector<tensorflow::uint8> binary;
+};
+
 }  // namespace gpu
 }  // namespace xla
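`GpuTargetBinary` replaces the untyped `std::pair<std::string, std::vector<uint8>>` and is move-only (its copy constructor is deleted), so compile results are handed off rather than duplicated. A small usage sketch with illustrative variable names:

```c++
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"

void Example(std::string ptx, std::vector<tensorflow::uint8> cubin) {
  // Aggregate-initialize, as both compilers now do.
  xla::gpu::GpuTargetBinary result{std::move(ptx), std::move(cubin)};
  // xla::gpu::GpuTargetBinary copy = result;        // error: copy is deleted.
  xla::gpu::GpuTargetBinary moved = std::move(result);  // OK: move is defaulted.
  (void)moved;
}
```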
 
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index d976b5d..0b5010e 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -18,7 +18,6 @@
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -42,7 +41,7 @@
       kernel_name_(kernel_name),
       unroll_factor_(unroll_factor) {}
 
-Status KernelThunk::Initialize(const GpuExecutable& executable,
+Status KernelThunk::Initialize(const GpuTargetBinary& target_binary,
                                se::StreamExecutor* executor) {
   tensorflow::mutex_lock lock(mutex_);
 
@@ -55,8 +54,10 @@
   if (kernel_cache_.end() == it) {
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<se::KernelBase> kernel,
-        CreateKernel(kernel_name_, args_.size(), executable.text(),
-                     executable.binary(), executor));
+        CreateKernel(kernel_name_, args_.size(), target_binary.text,
+                     target_binary.binary, executor));
+    CHECK(!target_binary.binary.empty());
+    CHECK(kernel);
 
     kernel_cache_.emplace(executor, std::move(kernel));
   }
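
KernelThunk::Initialize now receives only the GpuTargetBinary it actually needs, and keeps one loaded kernel per StreamExecutor in a mutex-guarded cache. A compact sketch of that create-on-first-miss cache pattern, with hypothetical stand-ins for the StreamExecutor and kernel types:

    #include <map>
    #include <memory>
    #include <mutex>
    #include <string>

    // Hypothetical stand-ins for se::StreamExecutor and se::KernelBase.
    struct StreamExecutor { int device_ordinal = 0; };
    struct KernelBase { std::string name; };

    class KernelCache {
     public:
      // Returns the kernel for `executor`, loading it at most once per executor.
      KernelBase* GetOrCreate(StreamExecutor* executor, const std::string& name) {
        std::lock_guard<std::mutex> lock(mu_);  // cf. tensorflow::mutex_lock
        auto it = cache_.find(executor);
        if (it == cache_.end()) {
          // The expensive step (CreateKernel in the patch) runs only on a miss.
          auto kernel = std::make_unique<KernelBase>();
          kernel->name = name;
          it = cache_.emplace(executor, std::move(kernel)).first;
        }
        return it->second.get();
      }

     private:
      std::mutex mu_;
      std::map<StreamExecutor*, std::unique_ptr<KernelBase>> cache_;
    };

    int main() {
      StreamExecutor exec0;
      KernelCache cache;
      KernelBase* k = cache.GetOrCreate(&exec0, "fusion_3");
      // A second lookup for the same executor returns the cached kernel.
      return k == cache.GetOrCreate(&exec0, "fusion_3") ? 0 : 1;
    }
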
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index 8835188..97a1d08 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -35,8 +35,6 @@
 namespace xla {
 namespace gpu {
 
-class GpuExecutable;
-
 // This class stores everything that StreamExecutor needs for launching a
 // kernel. It implements the ExecuteOnStream interface for GpuExecutable to
 // invoke the corresponding kernel.
@@ -58,7 +56,7 @@
   int unroll_factor() const { return unroll_factor_; }
   void SetLaunchDimensions(const LaunchDimensions& launch_dims);
 
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 0196267..cf6fe92 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -295,11 +295,9 @@
   return std::make_pair(cc_major, cc_minor);
 }
 
-StatusOr<std::pair<std::string, std::vector<uint8>>>
-NVPTXCompiler::CompileTargetBinary(const HloModule* module,
-                                   llvm::Module* llvm_module,
-                                   GpuVersion gpu_version,
-                                   se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> NVPTXCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   std::pair<int, int> compute_capability =
       absl::get<std::pair<int, int>>(gpu_version);
 
@@ -340,8 +338,7 @@
       stream_exec, ptx, compute_capability.first, compute_capability.second,
       module->config());
 
-  return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
-                                                    std::move(cubin));
+  return GpuTargetBinary{std::move(ptx), std::move(cubin)};
 }
 
 std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index e69be94..ec550b5 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -48,7 +48,7 @@
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
 
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
index 025ca60..bd26033 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -34,10 +34,10 @@
   }
 }
 
-Status SequentialThunk::Initialize(const GpuExecutable& executable,
+Status SequentialThunk::Initialize(const GpuTargetBinary& target_binary,
                                    se::StreamExecutor* executor) {
   for (auto& thunk : thunks_) {
-    TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
index 3abb82c..b547566 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
@@ -40,7 +40,7 @@
   const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; }
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index e9be41b..7aff9ca 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -22,6 +22,7 @@
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -30,8 +31,6 @@
 namespace xla {
 namespace gpu {
 
-class GpuExecutable;
-
 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
 //
@@ -97,7 +96,7 @@
   // This may be called multiple times.  Its main purpose is to give us a chance
   // to do initialization outside of ExecuteOnStream() so that the
   // time spent initializing doesn't count towards our execution profile.
-  virtual Status Initialize(const GpuExecutable& /*executable*/,
+  virtual Status Initialize(const GpuTargetBinary& /*target_binary*/,
                             se::StreamExecutor* /*executor*/) {
     return Status::OK();
   }
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index 4134cd3..2650508 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -45,11 +45,11 @@
   body_thunk_sequence_->ComputeAnnotations();
 }
 
-Status WhileThunk::Initialize(const GpuExecutable& executable,
+Status WhileThunk::Initialize(const GpuTargetBinary& target_binary,
                               se::StreamExecutor* executor) {
   TF_RETURN_IF_ERROR(
-      condition_thunk_sequence_->Initialize(executable, executor));
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+      condition_thunk_sequence_->Initialize(target_binary, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h
index 31db01b..77ee010 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h
@@ -47,7 +47,7 @@
   WhileThunk& operator=(const WhileThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
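SequentialThunk and WhileThunk forward the new argument unchanged to their children, so a single Initialize call fans out across the whole thunk tree, with the first failing child short-circuiting the walk (the role TF_RETURN_IF_ERROR plays above). A simplified sketch of that recursive propagation, with Status and the thunk types reduced to illustrative stand-ins:

    #include <memory>
    #include <utility>
    #include <vector>

    struct TargetBinary {};             // stands in for GpuTargetBinary
    struct Status { bool ok = true; };  // stands in for xla::Status

    class Thunk {
     public:
      virtual ~Thunk() = default;
      // The base implementation is a no-op, as in Thunk::Initialize above.
      virtual Status Initialize(const TargetBinary&) { return {}; }
    };

    class SequentialThunk : public Thunk {
     public:
      explicit SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks)
          : thunks_(std::move(thunks)) {}

      Status Initialize(const TargetBinary& tb) override {
        for (auto& thunk : thunks_) {
          Status s = thunk->Initialize(tb);  // recurse into children
          if (!s.ok) return s;               // short-circuit on first error
        }
        return {};
      }

     private:
      std::vector<std::unique_ptr<Thunk>> thunks_;
    };

    int main() {
      std::vector<std::unique_ptr<Thunk>> children;
      children.push_back(std::make_unique<Thunk>());
      children.push_back(std::make_unique<Thunk>());
      SequentialThunk seq(std::move(children));
      return seq.Initialize(TargetBinary{}).ok ? 0 : 1;
    }
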
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index de65ed9..9722d5c 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -420,6 +420,8 @@
     if (execution_options->num_partitions() > 0) {
       module_config.set_num_partitions(execution_options->num_partitions());
     }
+    module_config.set_use_spmd_partitioning(
+        execution_options->use_spmd_partitioning());
     if (execution_options->has_device_assignment()) {
       TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                           DeviceAssignment::Deserialize(
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index b31a9ae..833d0fe 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -128,6 +128,11 @@
   }
   int64 num_partitions() const { return num_partitions_; }
 
+  void set_use_spmd_partitioning(bool use_spmd_partitioning) {
+    use_spmd_partitioning_ = use_spmd_partitioning;
+  }
+  bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
+
   // Return a string which unambiguously represents all the fields of this data
   // structure. Used for generating a cache key for storing the compiled
   // executable.
@@ -199,6 +204,14 @@
 
   std::vector<std::vector<int64>>* mutable_dot_config() { return &dot_config_; }
 
+  absl::Span<const std::vector<std::vector<int64>>> layout_config() const {
+    return layout_config_;
+  }
+
+  std::vector<std::vector<std::vector<int64>>>* mutable_layout_config() {
+    return &layout_config_;
+  }
+
  private:
   // If you add new members, be sure to update compilation_cache_key.
 
@@ -216,6 +229,10 @@
   // The number of partitions (model parallelism) to compile this binary for.
   int64 num_partitions_ = 1;
 
+  // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 1 and
+  // XLA needs to partition the module.
+  bool use_spmd_partitioning_ = false;
+
   // The target maximum parallelism at which to partition HLOs for parallel
   // execution on the CPU backend.
   int64 intra_op_parallelism_threads_ = -1;
@@ -232,6 +249,9 @@
   FusionConfigCollection fusion_config_collection_ =
       FusionConfigCollection::kOff;
 
+  // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto
+  // similar to backend config.
+
   // Custom fusion configuration, where fusion_config_[c][v] control if node v
   // in computation c must be fused to all its consumers (true) or not (false).
   std::vector<std::vector<bool>> fusion_config_;
@@ -240,6 +260,10 @@
   // how to convert dot operation v (sorted topologically and by computation) to
   // convolution.
   std::vector<std::vector<int64>> dot_config_;
+
+  // Layout configuration, where layout_config_[v][i] controls the layout
+  // decision i of operation v.
+  std::vector<std::vector<std::vector<int64>>> layout_config_;
 };
 
 }  // namespace xla
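
The new HloModuleConfig members follow the file's existing setter/getter/mutable_* convention, and layout_config() returns an absl::Span so callers get a read-only, non-owning view of the nested vectors. A small sketch of that accessor shape (an illustrative subset, not the real class):

    #include <cstdint>
    #include <vector>

    #include "absl/types/span.h"

    // Illustrative subset of HloModuleConfig, showing the accessor shapes.
    class ModuleConfig {
     public:
      void set_use_spmd_partitioning(bool value) { use_spmd_partitioning_ = value; }
      bool use_spmd_partitioning() const { return use_spmd_partitioning_; }

      // Read-only, non-owning view: callers can iterate but not mutate.
      absl::Span<const std::vector<std::vector<int64_t>>> layout_config() const {
        return layout_config_;
      }
      // Mutable escape hatch for passes that populate the config in place.
      std::vector<std::vector<std::vector<int64_t>>>* mutable_layout_config() {
        return &layout_config_;
      }

     private:
      bool use_spmd_partitioning_ = false;
      std::vector<std::vector<std::vector<int64_t>>> layout_config_;
    };

    // Typical plumbing, mirroring local_service.cc and service.cc in this patch:
    //   config.set_use_spmd_partitioning(execution_options.use_spmd_partitioning());
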
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 360c8e5..d15a365 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -662,9 +662,11 @@
           shape_size_function_(bitcast->operand(0)->shape())) {
     return InternalError(
         "Bitcast cannot have different shape sizes of output (%d) and operand "
-        "(%d)",
+        "(%d) (%s) (%s)",
         shape_size_function_(bitcast->shape()),
-        shape_size_function_(bitcast->operand(0)->shape()));
+        shape_size_function_(bitcast->operand(0)->shape()),
+        bitcast->shape().ToString(true),
+        bitcast->operand(0)->shape().ToString(true));
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 84654bf..13699f3 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -951,7 +951,8 @@
                 if (!Shape::Equal()
                          .IgnoreDynamicDimension()
                          .MinorToMajorOnlyInLayout()(instruction_subshape,
-                                                     buffer->shape())) {
+                                                     buffer->shape()) &&
+                    instruction->opcode() != HloOpcode::kBitcast) {
                   return InternalError(
                       "Layout of instruction %s at index {%s} does not match "
                       "source LogicalBuffer %s: %s vs %s",
@@ -1798,13 +1799,6 @@
   // potential bugs in the layout assignment pass that may accidentally use the
   // existing layout.
   for (HloInstruction* instruction : computation->instructions()) {
-    if (instruction->opcode() == HloOpcode::kBitcast) {
-      // bitcasts are inherently layout sensitive and so a bitcast instruction
-      // present in the IR before layout assignment is a bug.
-      return InternalError(
-          "Unexpected bitcast operation seen during layout assignment: %s.",
-          instruction->ToString());
-    }
     // Some instructions carry mandatory layouts in their shape.
     if (instruction->opcode() != HloOpcode::kInfeed &&
         !IsLayoutConstrainedCustomCall(instruction) &&
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 304a80c..6e57524 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -814,27 +814,6 @@
   EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
 }
 
-TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
-  auto builder = HloComputation::Builder(TestName());
-  auto constant0 = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
-          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
-  builder.AddInstruction(
-      HloInstruction::CreateBitcast(constant0->shape(), constant0));
-  auto m = CreateNewVerifiedModule();
-  m->AddEntryComputation(builder.Build());
-
-  ComputationLayout computation_layout(
-      m->entry_computation()->ComputeProgramShape());
-  LayoutAssignment layout_assignment(&computation_layout);
-  Status error_status = layout_assignment.Run(m.get()).status();
-  EXPECT_FALSE(error_status.ok());
-  EXPECT_THAT(
-      error_status.error_message(),
-      ::testing::HasSubstr(
-          "Unexpected bitcast operation seen during layout assignment"));
-}
-
 TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
   // Pin non matching layouts to parameter and root.
   const char* module_str = R"(
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index ef8ddfc..c80646e 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -112,6 +112,8 @@
   }
   execution_options.set_num_replicas(build_options.num_replicas());
   execution_options.set_num_partitions(build_options.num_partitions());
+  execution_options.set_use_spmd_partitioning(
+      build_options.use_spmd_partitioning());
   if (build_options.has_device_assignment()) {
     TF_CHECK_OK(build_options.device_assignment().Serialize(
         execution_options.mutable_device_assignment()));
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
index cd679f7..a57e430 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD
+++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
@@ -185,11 +185,11 @@
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgToLLVM",
         "@llvm-project//mlir:LinalgTransforms",
-        "@llvm-project//mlir:LoopOps",
-        "@llvm-project//mlir:LoopOpsTransforms",
         "@llvm-project//mlir:LoopsToGPUPass",
         "@llvm-project//mlir:NVVMDialect",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 33d3690..847ad91 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -31,9 +31,9 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // from @llvm-project
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
 #include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
-#include "mlir/Dialect/LoopOps/LoopOps.h"  // from @llvm-project
-#include "mlir/Dialect/LoopOps/Passes.h"  // from @llvm-project
-#include "mlir/Dialect/LoopOps/Transforms.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/Passes.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
@@ -45,6 +45,7 @@
 #include "mlir/IR/Region.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
@@ -60,34 +61,6 @@
 
 using ::mlir::xla_lhlo::FusionOp;
 
-// Following are some small transformations that are required to clean up code
-// after lowering from linalg to loops.
-
-// A simple pass that applies lowering of HLO to LHLO only within LHLO ops that
-// contain regions with HLO ops, e.g. FusionOp, ReduceOp, SelectAndScatterOp.
-// This is needed, as these ops are not closed from above and hence nested pass
-// managers can not be applied.
-struct NestedHloRegionsConverter
-    : public mlir::PassWrapper<NestedHloRegionsConverter,
-                               ::mlir::FunctionPass> {
-  void runOnFunction() override {
-    auto& ctx = getContext();
-    mlir::OwningRewritePatternList patterns;
-    mlir::ConversionTarget target(ctx);
-    target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>();
-    ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns);
-
-    getFunction().walk([&](mlir::Operation* op) {
-      if (op->getNumRegions() == 0) {
-        return;
-      }
-      if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
-        signalPassFailure();
-      }
-    });
-  }
-};
-
 // Replaces a FusionOp by the operations contained in its region.
 struct FusionOpRemover
     : public mlir::PassWrapper<FusionOpRemover, ::mlir::FunctionPass> {
@@ -132,7 +105,7 @@
     // No store operation found. Continue search outside of the parallel
     // loop if block is in a parallel loop.
     if (auto parallelOp =
-            llvm::dyn_cast<mlir::loop::ParallelOp>(block->getParentOp())) {
+            llvm::dyn_cast<mlir::scf::ParallelOp>(block->getParentOp())) {
       return findStore(parallelOp.getOperation(), matches);
     }
     return {};
@@ -388,8 +361,8 @@
 struct FuseInnerParallelLoops
     : public mlir::PassWrapper<FuseInnerParallelLoops, mlir::FunctionPass> {
   void runOnFunction() override {
-    getFunction().walk([](mlir::loop::ParallelOp op) {
-      mlir::loop::naivelyFuseParallelOps(op.region());
+    getFunction().walk([](mlir::scf::ParallelOp op) {
+      mlir::scf::naivelyFuseParallelOps(op.region());
     });
   }
 };
@@ -401,7 +374,7 @@
   void runOnOperation() override {
     mlir::Operation* module = getOperation();
 
-    module->walk([&](mlir::loop::ParallelOp op) {
+    module->walk([&](mlir::scf::ParallelOp op) {
       unsigned num_loops = op.getNumLoops();
       std::vector<unsigned> combinedLoops;
       combinedLoops.reserve(num_loops);
@@ -436,8 +409,10 @@
     tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
   }
 
-  // First, lower bodies of LHLO operations that contain HLO ops.
-  pm.addPass(absl::make_unique<NestedHloRegionsConverter>());
+  // Legalize from HLO to LHLO.
+  pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass());
+  // Move `AllocOp`s and insert missing `DeallocOp`s.
+  pm.addPass(::mlir::createBufferPlacementPass());
   // Next, we can strip the outer fusion operation.
   pm.addPass(absl::make_unique<FusionOpRemover>());
   // Remove unnecessary LHLO copies.
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
index 35ac3b2..667cdef 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
@@ -549,10 +549,11 @@
   }
 
   // TODO(b/137624192): Add profiling support.
+
   return {absl::make_unique<GpuExecutable>(
-      ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
-      emission_context.releaseHloModule(), std::move(buffer_assignment),
-      nullptr, nullptr)};
+      xla::gpu::GpuTargetBinary{ptx, cubin}, GetGpuVersion(stream_exec),
+      std::move(thunk_schedule), emission_context.releaseHloModule(),
+      std::move(buffer_assignment), nullptr, nullptr)};
 }
 
 StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ab71c30..2ed5e70 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -313,6 +313,8 @@
     if (execution_options->num_partitions() > 0) {
       config->set_num_partitions(execution_options->num_partitions());
     }
+    config->set_use_spmd_partitioning(
+        execution_options->use_spmd_partitioning());
     config->set_seed(execution_options->seed());
     config->set_launch_id(execution_options->launch_id());
     config->set_debug_options(execution_options->debug_options());
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index c8a242c..1ad1f83 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1909,7 +1909,7 @@
         # This test is tagged "manual" because it requires multiple GPUs, and
         # Forge only supports single-GPU tests.  Guitar skips "manual" tests
         # unless they're also tagged "guitar".
-        "guitar",
+        #  "guitar",  # Re-enable after b/156405690 is fixed.
         "manual",
         "multi_gpu",
         "no_oss",
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 1947f51..16ed022 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -55,16 +55,15 @@
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { return 0; }
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override {
     if (user_post_optimization_hook_) {
       user_post_optimization_hook_(*llvm_module);
     }
 
     std::vector<uint8> compiled_results;
-    return std::pair<std::string, std::vector<uint8>>(
-        "", std::move(compiled_results));
+    return GpuTargetBinary{"", std::move(compiled_results)};
   }
 };
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index a015af6..f4b08f45 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -333,6 +333,10 @@
 
   // Used to identify a set of programs that should be launched together.
   int32 launch_id = 10;
+
+  // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
+  // num_partitions > 1 and XLA is requested to partition the input program.
+  bool use_spmd_partitioning = 11;
 }
 
 message GetDeviceHandlesRequest {
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 28ef943..6b4874a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -105,19 +105,17 @@
 # For platform specific build config
 load(
     "//tensorflow/core/platform:build_config.bzl",
-    "tf_additional_all_protos",
     "tf_additional_lib_deps",
     "tf_additional_test_deps",
     "tf_jspb_proto_library",
     "tf_kernel_tests_linkstatic",
     "tf_lib_proto_parsing_deps",
     "tf_portable_deps_no_runtime",
+    "tf_portable_proto_lib",
     "tf_proto_library",
-    "tf_proto_library_cc",
     "tf_protos_all_impl",
     "tf_protos_grappler_impl",
     "tf_protos_profiler_impl",
-    "tf_pyclif_proto_library",
 )
 load(
     "//tensorflow/core/platform:rules_cc.bzl",
@@ -180,18 +178,18 @@
 # filegroup; e.g.  ones with individual proto_library targets.
 # LINT.IfChange
 COMMON_PROTO_SRCS = [
-    "protobuf/bfc_memory_map.proto",
-    "protobuf/config.proto",
-    "protobuf/cluster.proto",
-    "protobuf/debug.proto",
-    "protobuf/device_filters.proto",
-    "protobuf/device_properties.proto",
-    "protobuf/graph_debug_info.proto",
-    "protobuf/queue_runner.proto",
-    "protobuf/rewriter_config.proto",
-    "protobuf/tensor_bundle.proto",
-    "protobuf/saver.proto",
-    "protobuf/verifier_config.proto",
+    "//tensorflow/core/protobuf:bfc_memory_map.proto",
+    "//tensorflow/core/protobuf:config.proto",
+    "//tensorflow/core/protobuf:cluster.proto",
+    "//tensorflow/core/protobuf:debug.proto",
+    "//tensorflow/core/protobuf:device_filters.proto",
+    "//tensorflow/core/protobuf:device_properties.proto",
+    "//tensorflow/core/protobuf:graph_debug_info.proto",
+    "//tensorflow/core/protobuf:queue_runner.proto",
+    "//tensorflow/core/protobuf:rewriter_config.proto",
+    "//tensorflow/core/protobuf:tensor_bundle.proto",
+    "//tensorflow/core/protobuf:saver.proto",
+    "//tensorflow/core/protobuf:verifier_config.proto",
 ]
 
 EXAMPLE_PROTO_SRCS = [
@@ -238,7 +236,7 @@
 ]
 
 ERROR_CODES_PROTO_SRCS = [
-    "protobuf/error_codes.proto",
+    "//tensorflow/core/protobuf:error_codes.proto",
     "//tensorflow/core/lib/core:error_codes.proto",
 ]
 # LINT.ThenChange(//tensorflow/core/portable_proto_config.asciipb)
@@ -251,11 +249,13 @@
     cc_api_version = 2,
     make_default_target_header_only = True,
     protodeps = [
-        ":core_protos",
-        ":error_codes_proto_impl",
         "//tensorflow/core/example:protos_all",
         "//tensorflow/core/framework:protos_all",
         "//tensorflow/core/lib/core:error_codes_proto",
+        "//tensorflow/core/profiler/protobuf:xplane_proto",
+        "//tensorflow/core/profiler:profiler_options_proto",
+        "//tensorflow/core/protobuf:error_codes_proto_impl",
+        "//tensorflow/core/protobuf:for_core_protos",
         "//tensorflow/core/util:protos_all",
         "//tensorflow/core/util:test_log_proto_impl",
     ],
@@ -1270,7 +1270,7 @@
         "//tensorflow/core/platform:mobile_srcs_no_runtime",
         "//tensorflow/core/public:mobile_srcs_no_runtime",
         "//tensorflow/core/util:mobile_srcs_no_runtime",
-        "//tensorflow/core/util/ctc:android_srcs",
+        "//tensorflow/core/util/ctc:mobile_srcs",
     ] + glob(
         [
             "client/**/*.cc",
@@ -1300,12 +1300,12 @@
         "//tensorflow/core/common_runtime/eager:srcs",
         "//tensorflow/core/framework:mobile_srcs_only_runtime",
         "//tensorflow/core/graph:mobile_srcs_only_runtime",
-        "//tensorflow/core/kernels:android_srcs",
+        "//tensorflow/core/kernels:mobile_srcs",
         "//tensorflow/core/lib/io:mobile_srcs_only_runtime",
         "//tensorflow/core/profiler:mobile_srcs",
         "//tensorflow/core/public:mobile_srcs_only_runtime",
         "//tensorflow/core/util/sparse:mobile_srcs_only_runtime",
-        "//tensorflow/core/util/tensor_bundle:android_srcs",
+        "//tensorflow/core/util/tensor_bundle:mobile_srcs",
         "//tensorflow/core/util:mobile_srcs_only_runtime",
 
         # Sources for which we already have granular targets.
@@ -1378,10 +1378,9 @@
     ],
     visibility = ["//visibility:public"],
     deps = [
-        ":protos_all_cc_impl",
         "//tensorflow/core/util:stats_calculator_portable",
         "//tensorflow/core:mobile_additional_lib_deps",
-    ] + tf_portable_deps_no_runtime(),
+    ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(),
     alwayslink = 1,
 )
 
@@ -1603,20 +1602,13 @@
     [
         alias(
             name = "protobuf_%s_pyclif%s" % (proto_name, target_suffix),
-            actual = ":protobuf/%s_pyclif%s" % (proto_name, target_suffix),
+            actual = "//tensorflow/core/protobuf:%s_pyclif%s" % (proto_name, target_suffix),
             visibility = ["//visibility:public"],
         )
         for target_suffix in [
             "",
             "_pb2",
         ]
-    ] + [
-        tf_pyclif_proto_library(
-            name = "protobuf/%s_pyclif" % proto_name,
-            proto_lib = ":protos_all",
-            proto_srcfile = "protobuf/%s.proto" % proto_name,
-            visibility = ["//visibility:public"],
-        ),
     ]
     for proto_name in [
         "config",
@@ -1630,77 +1622,74 @@
 # -----------------------------------------------------------------------------
 # Internal targets
 
-tf_proto_library(
+alias(
     name = "autotuning_proto",
-    srcs = ["protobuf/autotuning.proto"],
-    cc_api_version = 2,
-    make_default_target_header_only = True,
+    actual = "//tensorflow/core/protobuf:autotuning_proto",
     visibility = [
         "//tensorflow:internal",
     ],
 )
 
-tf_proto_library(
+alias(
+    name = "autotuning_proto_cc",
+    actual = "//tensorflow/core/protobuf:autotuning_proto_cc",
+    visibility = [
+        "//tensorflow:internal",
+    ],
+)
+
+alias(
     name = "conv_autotuning_proto",
-    srcs = ["protobuf/conv_autotuning.proto"],
-    cc_api_version = 2,
-    make_default_target_header_only = True,
-    protodeps = [
-        "//tensorflow/stream_executor:dnn_proto",
-    ],
+    actual = "//tensorflow/core/protobuf:conv_autotuning_proto",
     visibility = [
         "//tensorflow:internal",
     ],
 )
 
-tf_proto_library_cc(
-    name = "worker_proto",
-    srcs = ["protobuf/worker.proto"],
-    cc_api_version = 2,
-    protodeps = tf_additional_all_protos(),
-    visibility = ["//visibility:public"],
-)
-
-tf_proto_library_cc(
-    name = "worker_service_proto",
-    srcs = ["protobuf/worker_service.proto"],
-    has_services = 1,
-    cc_api_version = 2,
-    cc_stubby_versions = ["2"],
-    protodeps = [":worker_proto"],
+alias(
+    name = "conv_autotuning_proto_cc",
+    actual = "//tensorflow/core/protobuf:conv_autotuning_proto_cc",
     visibility = [
         "//tensorflow:internal",
     ],
 )
 
-tf_proto_library_cc(
-    name = "master_proto",
-    srcs = ["protobuf/master.proto"],
-    cc_api_version = 2,
-    protodeps = tf_additional_all_protos(),
-    visibility = ["//tensorflow:internal"],
-)
-
-tf_proto_library_cc(
-    name = "master_service_proto",
-    srcs = ["protobuf/master_service.proto"],
-    has_services = 1,
-    cc_api_version = 2,
-    cc_stubby_versions = ["2"],
-    protodeps = [":master_proto"],
+alias(
+    name = "worker_proto_cc",
+    actual = "//tensorflow/core/protobuf:worker_proto_cc",
     visibility = [
         "//tensorflow:internal",
     ],
 )
 
-tf_proto_library_cc(
-    name = "eager_service_proto",
-    srcs = ["protobuf/eager_service.proto"],
-    has_services = 1,
-    cc_api_version = 2,
-    cc_grpc_version = 1,
-    cc_stubby_versions = ["2"],
-    protodeps = tf_additional_all_protos(),
+alias(
+    name = "worker_service_proto_cc",
+    actual = "//tensorflow/core/protobuf:worker_service_proto_cc",
+    visibility = [
+        "//tensorflow:internal",
+    ],
+)
+
+alias(
+    name = "master_proto_cc",
+    actual = "//tensorflow/core/protobuf:master_proto_cc",
+    visibility = [
+        "//learning/brain/frameworks/uptc:__subpackages__",
+        "//tensorflow:internal",
+    ],
+)
+
+alias(
+    name = "master_service_proto_cc",
+    actual = "//tensorflow/core/protobuf:master_service_proto_cc",
+    visibility = [
+        "//tensorflow:internal",
+    ],
+)
+
+alias(
+    name = "eager_service_proto_cc",
+    actual = "//tensorflow/core/protobuf:eager_service_proto_cc",
     visibility = [
         "//tensorflow:internal",
     ],
@@ -2112,49 +2101,14 @@
     ],
 )
 
-tf_proto_library(
+alias(
     name = "error_codes_proto_impl",
-    srcs = ["protobuf/error_codes.proto"],
-    cc_api_version = 2,
-    make_default_target_header_only = True,
+    actual = "//tensorflow/core/protobuf:error_codes_proto_impl",
 )
 
-tf_proto_library(
-    name = "core_protos",
-    srcs = COMMON_PROTO_SRCS + [
-        # Protos which are not needed on mobile builds, but should be included
-        # in protos_all.
-        #
-        # Note that some protos are in neither core_proto_srcs nor this
-        # filegroup; e.g. ones with individual proto_library targets.
-        "protobuf/control_flow.proto",
-        # TODO(ebrevdo): Re-enable once CriticalSection is in core.
-        # "protobuf/critical_section.proto",
-        "protobuf/data/experimental/snapshot.proto",
-        "protobuf/debug_event.proto",
-        "protobuf/meta_graph.proto",
-        "protobuf/named_tensor.proto",
-        "protobuf/remote_tensor_handle.proto",
-        "protobuf/saved_model.proto",
-        "protobuf/saved_object_graph.proto",
-        "protobuf/struct.proto",
-        "protobuf/tensorflow_server.proto",
-        "protobuf/trackable_object_graph.proto",
-        "protobuf/transport_options.proto",
-    ],
-    cc_api_version = 2,
-    make_default_target_header_only = True,
-    protodeps = [
-        ":error_codes_proto_impl",
-        "//tensorflow/core/example:protos_all",
-        "//tensorflow/core/framework:protos_all",
-        "//tensorflow/core/lib/core:error_codes_proto",
-        "//tensorflow/core/profiler/protobuf:xplane_proto",
-        "//tensorflow/core/profiler:profiler_options_proto",
-        "//tensorflow/core/util:protos_all",
-        "//tensorflow/core/util:test_log_proto_impl",
-    ],
-    visibility = ["//visibility:private"],
+alias(
+    name = "error_codes_proto_impl_cc",
+    actual = "//tensorflow/core/protobuf:error_codes_proto_impl_cc",
 )
 
 alias(
@@ -2446,13 +2400,9 @@
     visibility = ["//visibility:public"],
 )
 
-tf_proto_library_cc(
-    name = "replay_log_proto",
-    srcs = ["protobuf/replay_log.proto"],
-    cc_api_version = 2,
-    protodeps = [
-        ":master_proto",
-    ] + tf_additional_all_protos(),
+alias(
+    name = "replay_log_proto_cc",
+    actual = "//tensorflow/core/protobuf:replay_log_proto_cc",
     visibility = [
         "//tensorflow:internal",
     ],
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 484bdee..016896b 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -243,6 +243,7 @@
         "memory_types.h",
         "mkl_cpu_allocator.h",
         "mkl_layout_pass.h",
+        "mkl_tfconversion_pass.h",
         "optimization_registry.h",
         "partitioning_utils.h",
         "placer.h",
@@ -1028,9 +1029,13 @@
 cc_library(
     name = "mkl_layout_pass",
     srcs = ["mkl_layout_pass.cc"],
-    hdrs = ["mkl_layout_pass.h"],
+    hdrs = [
+        "mkl_layout_pass.h",
+        "//tensorflow/core/graph:mkl_graph_util_header",
+    ],
     copts = tf_copts(),
     deps = [
+        ":function",
         ":optimization_registry",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -1043,9 +1048,13 @@
 cc_library(
     name = "mkl_tfconversion_pass",
     srcs = ["mkl_tfconversion_pass.cc"],
-    hdrs = ["mkl_tfconversion_pass.h"],
+    hdrs = [
+        "mkl_tfconversion_pass.h",
+        "//tensorflow/core/graph:mkl_graph_util_header",
+    ],
     copts = tf_copts(),
     deps = [
+        ":function",
         ":optimization_registry",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -2286,7 +2295,7 @@
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:scope",
         "//tensorflow/core/kernels:cwise_op",
-    ] + if_mkl([":mkl_array_ops_op_lib"]),
+    ] + if_mkl(["//tensorflow/core:mkl_array_ops_op_lib"]),
 )
 
 tf_cc_test(
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index 13630a0..7850978 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -98,7 +98,7 @@
 
 Status EagerExecutor::SyncExecute(EagerNode* node) {
   if (Async()) {
-    return errors::Internal("Executor does not support sync execution");
+    return errors::Internal("Executor does not support async execution");
   }
   if (node->AsAsync() != nullptr) {
     return errors::Internal("Executor does not support executing async nodes");
diff --git a/tensorflow/core/common_runtime/eager/execute_node_test.cc b/tensorflow/core/common_runtime/eager/execute_node_test.cc
index 970307d..99f0303 100644
--- a/tensorflow/core/common_runtime/eager/execute_node_test.cc
+++ b/tensorflow/core/common_runtime/eager/execute_node_test.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/core/common_runtime/eager/execute_node.h"
 
+#include <memory>
+
 #include "tensorflow/core/common_runtime/composite_device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
@@ -49,9 +51,12 @@
   StaticDeviceMgr device_mgr(
       DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
   Device* device0 = device_mgr.ListDevices().at(0);
-  StaticDeviceMgr remote_device_mgr(
+  auto remote_device_mgr = absl::make_unique<DynamicDeviceMgr>();
+  std::vector<std::unique_ptr<Device>> remote_devices;
+  remote_devices.emplace_back(
       DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
-  Device* device1 = remote_device_mgr.ListDevices().at(0);
+  TF_ASSERT_OK(remote_device_mgr->AddDevices(std::move(remote_devices)));
+  Device* device1 = remote_device_mgr->ListDevices().at(0);
 
   Status s;
   std::unique_ptr<CompositeDevice> composite_device =
@@ -65,6 +70,17 @@
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
       &device_mgr, false, nullptr, nullptr, nullptr);
 
+  // Set a RemoteMgr to the EagerContext.
+  auto remote_mgr = absl::make_unique<eager::RemoteMgr>(
+      /*is_master=*/true, ctx);
+  TF_ASSERT_OK(ctx->InitializeRemoteMaster(
+      /*server=*/nullptr, /*worker_env=*/nullptr,
+      /*worker_session=*/nullptr, /*remote_eager_workers=*/nullptr,
+      std::move(remote_device_mgr), /*remote_contexts=*/{},
+      EagerContext::NewContextId(),
+      /*r=*/nullptr, &device_mgr, /*keep_alive_secs*/ 600,
+      /*cluster_flr=*/nullptr, std::move(remote_mgr)));
+
   DataType dtype = DT_FLOAT;
   Tensor t0(dtype, TensorShape({}));
   // Create two local TensorHandles
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
index 2d4ae33..f233980 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -17,7 +17,7 @@
 #include <unordered_map>
 
 #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
-#include "tensorflow/core/common_runyime/mkl_layout_pass.h"
+#include "tensorflow/core/common_runtime/mkl_layout_pass.h"
 #include "tensorflow/core/graph/mkl_graph_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/util/mkl_util.h"
diff --git a/tensorflow/core/data/service/master_impl.cc b/tensorflow/core/data/service/master_impl.cc
index 6e2c95c..336ab06 100644
--- a/tensorflow/core/data/service/master_impl.cc
+++ b/tensorflow/core/data/service/master_impl.cc
@@ -169,7 +169,11 @@
   if (job != nullptr) {
     TF_RETURN_IF_ERROR(ValidateMatchingJob(**job, requested_processing_mode,
                                            request->dataset_id()));
-    response->set_job_id((*job)->job_id());
+    int64 job_id = (*job)->job_id();
+    response->set_job_id(job_id);
+    VLOG(3) << "Found existing job for name=" << request->job_name()
+            << ", index=" << request->job_name_index()
+            << ". job_id: " << job_id;
     return Status::OK();
   }
   int64 job_id;
@@ -177,6 +181,8 @@
                                request->job_name(), &job_id));
   named_jobs_[key] = jobs_[job_id];
   response->set_job_id(job_id);
+  VLOG(3) << "Created job " << job_id << " for dataset "
+          << request->dataset_id() << " and name " << request->job_name();
   return Status::OK();
 }
 
diff --git a/tensorflow/core/data/service/master_impl.h b/tensorflow/core/data/service/master_impl.h
index de25ea0..e8b70e8 100644
--- a/tensorflow/core/data/service/master_impl.h
+++ b/tensorflow/core/data/service/master_impl.h
@@ -75,7 +75,7 @@
     }
 
     std::string DebugString() {
-      return absl::StrCat("id: ", worker_id_, "address: ", address_);
+      return absl::StrCat("id: ", worker_id_, " address: ", address_);
     }
 
    private:
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index 7395244..8d00825 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -84,6 +84,7 @@
 
 Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  VLOG(3) << "Received request to process task " << task_def.task_id();
   standalone::Dataset::Params params;
   std::unique_ptr<standalone::Dataset> dataset;
   TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
@@ -100,6 +101,7 @@
   task.id = task_def.task_id();
   task.dataset = std::move(dataset);
   task.iterator = std::move(iterator);
+  VLOG(3) << "Began processing task " << task_def.task_id();
   return Status::OK();
 }
 
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index 40a6d53..361f7ed 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -16,6 +16,7 @@
 #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
 
 #include <vector>
+
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -90,7 +91,7 @@
         ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \
             op, i, "e")                                                     \
             .error_message();                                               \
-    const std::string& substring = error_substring;                         \
+    const std::string substring = error_substring;                          \
     EXPECT_NE("", error_message);                                           \
     EXPECT_TRUE(absl::StrContains(error_message, substring))                \
         << "Expected to see '" << substring << "' in '" << error_message    \
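
Dropping the `&` from `substring` is a lifetime fix, not a style change: `error_substring` is a macro argument, and if it expands to an accessor that returns a reference into a temporary object, binding that reference to a local `const std::string&` leaves it dangling once the temporary dies at the end of the statement (lifetime extension only applies to prvalues). Copying by value is safe for every expansion. A stand-alone illustration of the hazard, with a hypothetical Message type:

    #include <iostream>
    #include <string>

    struct Message {  // stands in for a temporary (e.g. a proto message)
      std::string field = "expected-substring";
      const std::string& name() const { return field; }
    };

    Message MakeMessage() { return Message{}; }

    int main() {
      // DANGLING: name() returns a reference into the temporary Message, and
      // lifetime extension does not apply to references returned by member
      // functions -- the Message dies at the end of this statement.
      // const std::string& bad = MakeMessage().name();

      // SAFE: copying by value, as the patch does for `substring`, keeps the
      // data alive regardless of what the macro argument expands to.
      const std::string good = MakeMessage().name();
      std::cout << good << "\n";
      return 0;
    }
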
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index f47265f..cd0d44e 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -114,12 +114,12 @@
 }
 
 uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
-  const uint64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000;
+  const uint64 kTwentyMinutesInUsec = 20 * 60 * 1000 * 1000;
   if (cfg.meta_optimizer_timeout_ms() < 0) {
     return 0;
   } else {
     return cfg.meta_optimizer_timeout_ms() == 0
-               ? Env::Default()->NowMicros() + kFiveMinutesInUsec
+               ? Env::Default()->NowMicros() + kTwentyMinutesInUsec
                : Env::Default()->NowMicros() +
                      cfg.meta_optimizer_timeout_ms() * 1000;
   }
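
The default deadline grows from 5 to 20 minutes; in microseconds that constant is 20 * 60 * 1000 * 1000 = 1,200,000,000, which still fits comfortably in a uint64 (and even in an int32). A stand-alone rendering of the same arithmetic, with a fixed clock substituted for Env::Default()->NowMicros():

    #include <cstdint>
    #include <iostream>

    // Mirrors DeadlineMicroSeconds: negative disables the deadline entirely,
    // zero selects the 20-minute default, anything else is an explicit timeout.
    uint64_t DeadlineUsec(int64_t timeout_ms, uint64_t now_usec) {
      const uint64_t kTwentyMinutesInUsec = 20ULL * 60 * 1000 * 1000;
      if (timeout_ms < 0) return 0;
      return timeout_ms == 0 ? now_usec + kTwentyMinutesInUsec
                             : now_usec + timeout_ms * 1000;
    }

    int main() {
      std::cout << DeadlineUsec(0, 0) << "\n";     // 1200000000 usec = 20 min
      std::cout << DeadlineUsec(5000, 0) << "\n";  // explicit 5 s = 5000000 usec
    }
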
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index e47c681..7cfb6fc 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -7096,7 +7096,7 @@
 
 build_test(
     name = "android_tensorflow_kernels_build_test",
-    targets = [":android_tensorflow_kernels"],
+    targets = [":portable_tensorflow_kernels"],
 )
 
 cc_library(
@@ -7109,7 +7109,7 @@
         "//tensorflow/core:android_gif_internal",
         "//tensorflow/core:android_jpeg_internal",
         "//tensorflow/core:android_png_internal",
-        "//tensorflow/core:android_tensorflow_lib_lite",
+        "//tensorflow/core:portable_tensorflow_lib_lite",
     ],
     alwayslink = 1,
 )
@@ -7126,7 +7126,7 @@
     linkopts = ["-ldl"],
     visibility = ["//visibility:public"],
     deps = [
-        "//tensorflow/core:android_tensorflow_lib_lite",
+        "//tensorflow/core:portable_tensorflow_lib_lite",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index d61c574..4ddfd99 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -138,6 +138,7 @@
         "//tensorflow/core/kernels/data:dataset_utils",
         "//tensorflow/core/kernels/data:name_utils",
         "//tensorflow/core/kernels/data:serialization_utils",
+        "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index 8c33668..697f4d9 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -37,6 +37,7 @@
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/snappy.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
 
 namespace tensorflow {
@@ -178,7 +179,10 @@
   class Iterator : public DatasetIterator<Dataset> {
    public:
     explicit Iterator(const Params& params, int64 iterator_index)
-        : DatasetIterator<Dataset>(params), iterator_index_(iterator_index) {}
+        : DatasetIterator<Dataset>(params),
+          iterator_index_(iterator_index),
+          max_outstanding_requests_(params.dataset->max_outstanding_requests_) {
+    }
 
     ~Iterator() override {
       mutex_lock l(mu_);
@@ -390,21 +394,23 @@
         // TODO(aaudibert): add backoff and max retries.
         int64 deadline_micros =
             Env::Default()->NowMicros() + kRetryTimeoutMicros;
-        Status s = FetchElement(task_thread, deadline_micros);
+        Status s = GetElement(task_thread, deadline_micros);
         if (!s.ok()) {
-          LOG(WARNING) << "Failed to fetch element from worker at "
+          LOG(WARNING) << "Failed to get element from worker at "
                        << task_thread->address << ": " << s;
         }
       }
     }
 
-    // Fetches an element from a task and adds the element to `results_`.
+    // Gets an element from a task and adds the element to `results_`.
     //
     // If the task reaches end_of_sequence or is cancelled (e.g. due to a
-    // worker dying), FetchElement returns Status::OK() without adding to
+    // worker dying), GetElement returns Status::OK() without adding to
     // `results_`.
-    Status FetchElement(TaskThread* task_thread, int64 deadline_micros) {
-      VLOG(3) << "Fetching an element for task id " << task_thread->task_id;
+    Status GetElement(TaskThread* task_thread, int64 deadline_micros) {
+      VLOG(3) << "Getting an element for task id " << task_thread->task_id;
+      tensorflow::profiler::TraceMe activity(
+          "GetElement", tensorflow::profiler::TraceMeLevel::kInfo);
       CompressedElement compressed;
       bool end_of_sequence;
       for (int num_retries = 0;; ++num_retries) {
@@ -453,7 +459,7 @@
       }
       results_.push(std::move(element));
       cv_.notify_all();
-      VLOG(3) << "Fetched an element for task id " << task_thread->task_id;
+      VLOG(3) << "Got an element for task id " << task_thread->task_id;
       return Status::OK();
     }
 
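The renamed GetElement keeps the fixed-deadline retry discipline (backoff and a retry cap remain a TODO above) and now opens a profiler::TraceMe so each attempt shows up in traces. The retry shape, reduced to a self-contained sketch with std::chrono standing in for Env::Default():

    #include <chrono>
    #include <functional>
    #include <string>
    #include <thread>

    struct Status {
      bool ok = true;
      std::string message;
    };

    // Schematic of the retry discipline around GetElement: keep retrying a
    // transient failure until `deadline` passes, then surface the last error.
    Status RetryUntilDeadline(const std::function<Status()>& attempt,
                              std::chrono::steady_clock::time_point deadline) {
      for (int num_retries = 0;; ++num_retries) {
        Status s = attempt();
        if (s.ok) return s;
        if (std::chrono::steady_clock::now() >= deadline) return s;
        // The patch leaves backoff as a TODO; a short fixed sleep stands in.
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
      }
    }

    int main() {
      int attempts = 0;
      auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(2);
      Status s = RetryUntilDeadline(
          [&] {
            ++attempts;
            return attempts < 3 ? Status{false, "transient"} : Status{};
          },
          deadline);
      return s.ok ? 0 : 1;
    }
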
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 5dae096..7b8f697 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -621,6 +621,11 @@
       return false;
     }
     if (!deterministic_) {
+      // Iterate through the in-flight results and return the first one that
+      // is available and not end-of-input. If the first result (in order) is
+      // end-of-input, all earlier invocations have already completed, so it
+      // is safe to return that result and let the caller handle the end of
+      // iteration.
       for (auto it = invocation_results_.begin();
            it != invocation_results_.end(); ++it) {
         if ((*it)->notification.HasBeenNotified() &&
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 0035677..42364e4 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -435,9 +435,9 @@
     for (const string& dump_root : dump_roots_) {
       tfdbg::DebugEventsWriter* debug_events_writer =
           tfdbg::DebugEventsWriter::GetDebugEventsWriter(dump_root);
-      debug_events_writer->WriteGraphExecutionTrace(
-          tfdbg_context_id_, device_name_, op_name_, output_slot_,
-          tensor_debug_mode_, tensor);
+      OP_REQUIRES_OK(context, debug_events_writer->WriteGraphExecutionTrace(
+                                  tfdbg_context_id_, device_name_, op_name_,
+                                  output_slot_, tensor_debug_mode_, tensor));
     }
     context->set_output(0, tensor);
   }
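
Wrapping the write in OP_REQUIRES_OK means a failed WriteGraphExecutionTrace now fails the op with the underlying error instead of being silently discarded. A simplified, self-contained analogue of that macro's record-and-return behavior (Status and the context type here are toy stand-ins, not the TensorFlow API):

    #include <iostream>
    #include <string>

    struct Status {
      bool ok = true;
      std::string message;
    };

    struct OpKernelContext {
      Status status;
      void CtxFailure(const Status& s) { status = s; }
    };

    // Simplified analogue of OP_REQUIRES_OK: record the error on the context
    // and return from Compute() instead of discarding the Status.
    #define REQUIRES_OK(ctx, expr)  \
      do {                          \
        Status _s = (expr);         \
        if (!_s.ok) {               \
          (ctx)->CtxFailure(_s);    \
          return;                   \
        }                           \
      } while (0)

    Status WriteTrace(bool fail) {
      return fail ? Status{false, "disk full"} : Status{};
    }

    void Compute(OpKernelContext* ctx, bool fail) {
      REQUIRES_OK(ctx, WriteTrace(fail));  // previously the Status was ignored
      // ... set outputs only if the write succeeded ...
    }

    int main() {
      OpKernelContext ctx;
      Compute(&ctx, /*fail=*/true);
      std::cout << (ctx.status.ok ? "ok" : ctx.status.message) << "\n";
    }
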
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 0f5a701..3b38daf 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -61,7 +61,9 @@
                                 " is '" +
                                 DataTypeString(ctx->output_type(0)) + "'"));
 
+    need_cast_ = true;
     if (ctx->output_type(0) == DT_FLOAT) {
+      need_cast_ = false;
       OP_REQUIRES(ctx,
                   (mode_string == "MIN_COMBINED" ||
                    mode_string == "MIN_FIRST" || mode_string == "SCALED"),
@@ -98,8 +100,9 @@
     }
 
     Tensor* output = nullptr;
-    Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
+    Tensor float_output =
+        need_cast_ ? tensorflow::Tensor(DT_FLOAT, input.shape()) : *output;
     if (num_slices == 1) {
       const float min_range = input_min_tensor.flat<float>()(0);
       const float max_range = input_max_tensor.flat<float>()(0);
@@ -128,10 +131,12 @@
                         max_ranges(i), output_tensor.template chip<1>(i));
       }
     }
-    S* out_ptr = output->flat<S>().data();
-    float* in_ptr = float_output.flat<float>().data();
-    for (int64 i = 0; i < float_output.NumElements(); ++i) {
-      out_ptr[i] = static_cast<S>(in_ptr[i]);
+    if (need_cast_) {
+      S* out_ptr = output->flat<S>().data();
+      float* in_ptr = float_output.flat<float>().data();
+      for (int64 i = 0; i < float_output.NumElements(); ++i) {
+        out_ptr[i] = static_cast<S>(in_ptr[i]);
+      }
     }
   }
 
@@ -219,6 +224,7 @@
   int mode_;
   int axis_;
   bool narrow_range_;
+  bool need_cast_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("Dequantize")
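
need_cast_ lets the kernel skip both the temporary float buffer and the elementwise cast whenever the requested output type is already DT_FLOAT, by aliasing float_output to the real output tensor. The same idea in a self-contained template sketch (C++17; the scale factor is a toy value, not the kernel's real dequantization math):

    #include <cstddef>
    #include <cstdint>
    #include <type_traits>
    #include <vector>

    // Dequantize `in` into an output of type S. When S is float we write
    // straight into the output buffer -- the need_cast_ == false path;
    // otherwise we compute in a temporary float buffer and cast once.
    template <typename S>
    std::vector<S> Dequantize(const std::vector<uint8_t>& in) {
      std::vector<S> out(in.size());
      const float scale = 0.5f;  // toy scale factor
      if constexpr (std::is_same_v<S, float>) {
        for (size_t i = 0; i < in.size(); ++i) out[i] = in[i] * scale;
      } else {
        std::vector<float> tmp(in.size());
        for (size_t i = 0; i < in.size(); ++i) tmp[i] = in[i] * scale;
        for (size_t i = 0; i < in.size(); ++i) out[i] = static_cast<S>(tmp[i]);
      }
      return out;
    }

    int main() {
      auto fast = Dequantize<float>({2, 4});   // no temporary, no cast loop
      auto slow = Dequantize<double>({2, 4});  // float buffer + cast, as before
      return fast[0] == 1.0f && slow[1] == 2.0 ? 0 : 1;
    }
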
diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD
index 80ad494..491e4c5 100644
--- a/tensorflow/core/lib/core/BUILD
+++ b/tensorflow/core/lib/core/BUILD
@@ -138,10 +138,13 @@
     cc_api_version = 2,
     make_default_target_header_only = True,
     protodeps = [
-        "//tensorflow/core:error_codes_proto_impl",
+        "//tensorflow/core/protobuf:error_codes_proto_impl",
     ],
-    visibility = ["//tensorflow/core:__subpackages__"],
-    exports = ["//tensorflow/core:error_codes_proto_impl"],
+    visibility = [
+        "//tensorflow/core:__subpackages__",
+        "//tensorflow/core/protobuf:__subpackages__",
+    ],
+    exports = ["//tensorflow/core/protobuf:error_codes_proto_impl"],
 )
 
 # Export source files needed for mobile builds, which do not use granular targets.
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 819f8fc..c7ff378 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -621,7 +621,7 @@
         ":stringpiece",
         ":stringprintf",
         ":types",
-        "//tensorflow/core:error_codes_proto_impl_cc",
+        "//tensorflow/core/protobuf:error_codes_proto_impl_cc",
         "@com_google_absl//absl/base",
     ],
 )
diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl
index f0613cd..ab45256 100644
--- a/tensorflow/core/platform/build_config.bzl
+++ b/tensorflow/core/platform/build_config.bzl
@@ -26,6 +26,7 @@
     _tf_platform_alias = "tf_platform_alias",
     _tf_platform_deps = "tf_platform_deps",
     _tf_portable_deps_no_runtime = "tf_portable_deps_no_runtime",
+    _tf_portable_proto_lib = "tf_portable_proto_lib",
     _tf_proto_library = "tf_proto_library",
     _tf_proto_library_cc = "tf_proto_library_cc",
     _tf_proto_library_py = "tf_proto_library_py",
@@ -65,6 +66,7 @@
 tf_logging_deps = _tf_logging_deps
 tf_platform_alias = _tf_platform_alias
 tf_platform_deps = _tf_platform_deps
+tf_portable_proto_lib = _tf_portable_proto_lib
 tf_portable_deps_no_runtime = _tf_portable_deps_no_runtime
 tf_proto_library = _tf_proto_library
 tf_proto_library_cc = _tf_proto_library_cc
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 18a8285..2dc4fdc 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -577,8 +577,8 @@
 
 def tf_protos_all_impl():
     return [
-        clean_dep("//tensorflow/core:autotuning_proto_cc_impl"),
-        clean_dep("//tensorflow/core:conv_autotuning_proto_cc_impl"),
+        clean_dep("//tensorflow/core/protobuf:autotuning_proto_cc_impl"),
+        clean_dep("//tensorflow/core/protobuf:conv_autotuning_proto_cc_impl"),
         clean_dep("//tensorflow/core:protos_all_cc_impl"),
     ]
 
@@ -727,6 +727,9 @@
         otherwise = [clean_dep("@com_google_protobuf//:protobuf_headers")],
     )
 
+def tf_portable_proto_lib():
+    return ["//tensorflow/core:protos_all_cc_impl"]
+
 def tf_protobuf_compiler_deps():
     return if_static(
         [
@@ -764,7 +767,7 @@
         "@nsync//:nsync_cpp",
         "@com_googlesource_code_re2//:re2",
         "@farmhash_archive//:farmhash",
-    ] + tf_protobuf_deps()
+    ]
 
 def tf_google_mobile_srcs_no_runtime():
     return []
diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.cc b/tensorflow/core/platform/profile_utils/cpu_utils.cc
index 587c978..b22123a 100644
--- a/tensorflow/core/platform/profile_utils/cpu_utils.cc
+++ b/tensorflow/core/platform/profile_utils/cpu_utils.cc
@@ -88,6 +88,8 @@
      defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__))
     retval = sscanf(line.c_str(), "clock              : %lfMHz", &cpu_freq);
     freq_factor = 1.0;
+#elif defined(__s390x__)
+    retval = sscanf(line.c_str(), "bogomips per cpu: %lf", &cpu_freq);
 #else
     retval = sscanf(line.c_str(), "bogomips : %lf", &cpu_freq);
 #endif
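
On s390x, /proc/cpuinfo reports "bogomips per cpu:" rather than "bogomips :", so the frequency probe gains one more sscanf pattern. A tiny stand-alone illustration of that per-line probing (the cpuinfo line below is illustrative):

    #include <cstdio>
    #include <string>

    // Probe one /proc/cpuinfo line for a clock estimate, trying the
    // architecture-specific spellings in turn (cf. cpu_utils.cc above).
    bool ParseCpuFreq(const std::string& line, double* freq) {
      return sscanf(line.c_str(), "bogomips per cpu: %lf", freq) == 1 ||  // s390x
             sscanf(line.c_str(), "bogomips : %lf", freq) == 1;           // x86 etc.
    }

    int main() {
      double f = 0;
      if (ParseCpuFreq("bogomips per cpu: 3241.00", &f)) {
        printf("%.2f\n", f);  // prints 3241.00
      }
      return 0;
    }
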
diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD
index 15a1ad0..369d26a 100644
--- a/tensorflow/core/profiler/convert/BUILD
+++ b/tensorflow/core/profiler/convert/BUILD
@@ -17,15 +17,18 @@
         "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:cost_utils",
-        "//tensorflow/core/profiler/utils:event_span",
+        "//tensorflow/core/profiler/utils:op_metrics_db_utils",
         "//tensorflow/core/profiler/utils:op_utils",
         "//tensorflow/core/profiler/utils:tf_op_utils",
         "//tensorflow/core/profiler/utils:tf_xplane_visitor",
         "//tensorflow/core/profiler/utils:timespan",
         "//tensorflow/core/profiler/utils:trace_utils",
+        "//tensorflow/core/profiler/utils:xplane_schema",
+        "//tensorflow/core/profiler/utils:xplane_visitor",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -35,9 +38,11 @@
     srcs = ["xplane_to_op_metrics_db_test.cc"],
     deps = [
         ":xplane_to_op_metrics_db",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:op_metrics_db_utils",
         "//tensorflow/core/profiler/utils:time_utils",
         "//tensorflow/core/profiler/utils:xplane_builder",
@@ -86,13 +91,15 @@
         ":op_stats_to_input_pipeline_analysis",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "//tensorflow/core/platform:logging",
         "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc",
         "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc",
         "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
         "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
         "//tensorflow/core/profiler/protobuf:overview_page_proto_cc",
+        "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
+        "//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
         "//tensorflow/core/profiler/utils:errors",
+        "//tensorflow/core/profiler/utils:html_utils",
         "//tensorflow/core/profiler/utils:math_utils",
         "//tensorflow/core/profiler/utils:op_metrics_db_utils",
         "//tensorflow/core/profiler/utils:time_utils",
@@ -118,11 +125,11 @@
         "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
         "//tensorflow/core/profiler/utils:errors",
         "//tensorflow/core/profiler/utils:event_span",
+        "//tensorflow/core/profiler/utils:html_utils",
         "//tensorflow/core/profiler/utils:math_utils",
         "//tensorflow/core/profiler/utils:tf_op_utils",
         "//tensorflow/core/profiler/utils:time_utils",
         "//tensorflow/core/util:stats_calculator_portable",
-        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -135,13 +142,12 @@
     hdrs = ["op_stats_to_tf_stats.h"],
     deps = [
         ":op_metrics_to_record",
+        "//tensorflow/core:lib",
         "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
         "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
         "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
         "//tensorflow/core/profiler/utils:op_metrics_db_utils",
-        "//tensorflow/core/profiler/utils:tf_op_utils",
         "//tensorflow/core/profiler/utils:time_utils",
-        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
@@ -152,13 +158,18 @@
     deps = [
         ":op_stats_to_tf_stats",
         ":xplane_to_op_stats",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+        "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
+        "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:op_metrics_db_utils",
         "//tensorflow/core/profiler/utils:time_utils",
         "//tensorflow/core/profiler/utils:xplane_builder",
         "//tensorflow/core/profiler/utils:xplane_schema",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -171,6 +182,9 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
         "//tensorflow/core/profiler/utils:event_span",
+        "//tensorflow/core/profiler/utils:timespan",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
 
@@ -205,6 +219,7 @@
     srcs = ["xplane_to_op_stats.cc"],
     hdrs = ["xplane_to_op_stats.h"],
     deps = [
+        ":op_metrics_db_combiner",
         ":step_events_to_steps_db",
         ":xplane_to_kernel_stats_db",
         ":xplane_to_op_metrics_db",
@@ -213,15 +228,20 @@
         "//tensorflow/core:lib",
         "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc",
         "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc",
+        "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
         "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
+        "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
         "//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:event_span",
         "//tensorflow/core/profiler/utils:hardware_type_utils",
         "//tensorflow/core/profiler/utils:kernel_stats_utils",
+        "//tensorflow/core/profiler/utils:tf_op_utils",
         "//tensorflow/core/profiler/utils:tf_xplane_visitor",
         "//tensorflow/core/profiler/utils:xplane_schema",
         "//tensorflow/core/profiler/utils:xplane_utils",
+        "//tensorflow/core/profiler/utils:xplane_visitor",
+        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
 
@@ -239,11 +259,15 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+        "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
         "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
+        "//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:group_events",
         "//tensorflow/core/profiler/utils:xplane_builder",
         "//tensorflow/core/profiler/utils:xplane_schema",
         "//tensorflow/core/profiler/utils:xplane_utils",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -259,7 +283,6 @@
         ":xplane_to_memory_profile",
         ":xplane_to_op_stats",
         ":xplane_to_trace_events",
-        "//tensorflow/core:human_readable_json",
         "//tensorflow/core:lib",
         "//tensorflow/core/profiler:profiler_service_proto_cc",
         "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc",
@@ -269,6 +292,7 @@
         "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
         "//tensorflow/core/profiler/protobuf:overview_page_proto_cc",
         "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
+        "//tensorflow/core/profiler/protobuf:trace_events_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/rpc/client:save_profile",
         "//tensorflow/core/profiler/utils:xplane_schema",
@@ -284,12 +308,14 @@
     srcs = ["xplane_to_profile_response_test.cc"],
     deps = [
         ":xplane_to_profile_response",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/profiler:profiler_service_proto_cc",
         "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc",
         "//tensorflow/core/profiler/protobuf:overview_page_proto_cc",
         "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:group_events",
         "//tensorflow/core/profiler/utils:xplane_builder",
         "//tensorflow/core/profiler/utils:xplane_schema",
@@ -303,14 +329,16 @@
     hdrs = ["xplane_to_step_events.h"],
     deps = [
         "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:event_span",
         "//tensorflow/core/profiler/utils:tf_xplane_visitor",
+        "//tensorflow/core/profiler/utils:timespan",
         "//tensorflow/core/profiler/utils:trace_utils",
         "//tensorflow/core/profiler/utils:xplane_schema",
+        "//tensorflow/core/profiler/utils:xplane_visitor",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -320,12 +348,16 @@
     srcs = ["xplane_to_step_events_test.cc"],
     deps = [
         ":xplane_to_step_events",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "//tensorflow/core/profiler/utils:event_span",
         "//tensorflow/core/profiler/utils:group_events",
         "//tensorflow/core/profiler/utils:xplane_builder",
         "//tensorflow/core/profiler/utils:xplane_schema",
         "//tensorflow/core/profiler/utils:xplane_utils",
+        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
 
@@ -339,7 +371,9 @@
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:tf_xplane_visitor",
         "//tensorflow/core/profiler/utils:xplane_schema",
+        "//tensorflow/core/profiler/utils:xplane_visitor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -355,6 +389,8 @@
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/core/profiler/protobuf:trace_events_proto_cc",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:xplane_builder",
         "//tensorflow/core/profiler/utils:xplane_schema",
         "//tensorflow/core/profiler/utils:xplane_utils",
@@ -370,14 +406,14 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
-        "//tensorflow/core/profiler/utils:event_span",
         "//tensorflow/core/profiler/utils:kernel_stats_utils",
         "//tensorflow/core/profiler/utils:tf_op_utils",
         "//tensorflow/core/profiler/utils:tf_xplane_visitor",
         "//tensorflow/core/profiler/utils:trace_utils",
         "//tensorflow/core/profiler/utils:xplane_schema",
-        "//tensorflow/core/profiler/utils:xplane_utils",
         "//tensorflow/core/profiler/utils:xplane_visitor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -390,14 +426,14 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "//tensorflow/core/profiler/utils:math_utils",
         "//tensorflow/core/profiler/utils:tf_xplane_visitor",
         "//tensorflow/core/profiler/utils:timespan",
         "//tensorflow/core/profiler/utils:xplane_schema",
-        "//tensorflow/core/profiler/utils:xplane_utils",
         "//tensorflow/core/profiler/utils:xplane_visitor",
         "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -414,10 +450,13 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:tf_xplane_visitor",
         "//tensorflow/core/profiler/utils:xplane_builder",
         "//tensorflow/core/profiler/utils:xplane_schema",
         "//tensorflow/core/profiler/utils:xplane_utils",
+        "//tensorflow/core/profiler/utils:xplane_visitor",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -429,15 +468,18 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/framework:protos_all_cc",
         "//tensorflow/core/platform:protobuf",
         "//tensorflow/core/profiler/protobuf:memory_profile_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:tf_xplane_visitor",
         "//tensorflow/core/profiler/utils:xplane_schema",
-        "//tensorflow/core/profiler/utils:xplane_utils",
+        "//tensorflow/core/profiler/utils:xplane_visitor",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -447,10 +489,14 @@
     srcs = ["xplane_to_memory_profile_test.cc"],
     deps = [
         ":xplane_to_memory_profile",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/profiler/protobuf:memory_profile_proto_cc",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:xplane_builder",
         "//tensorflow/core/profiler/utils:xplane_schema",
         "//tensorflow/core/profiler/utils:xplane_utils",
+        "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
index 3f601bb..8229d10 100644
--- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
+++ b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
 
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/op_metrics_to_record.cc b/tensorflow/core/profiler/convert/op_metrics_to_record.cc
index b51c679..8e28199 100644
--- a/tensorflow/core/profiler/convert/op_metrics_to_record.cc
+++ b/tensorflow/core/profiler/convert/op_metrics_to_record.cc
@@ -15,7 +15,9 @@
 
 #include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
 
+#include <iterator>
 #include <tuple>
+#include <vector>
 
 #include "absl/algorithm/container.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
index ca2a6c2..8367345 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
@@ -15,11 +15,12 @@
 
 #include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
 
+#include <math.h>
+
 #include <algorithm>
-#include <utility>
+#include <vector>
 
 #include "google/protobuf/any.pb.h"
-#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
@@ -27,7 +28,6 @@
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
 #include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
@@ -38,6 +38,7 @@
 #include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
 #include "tensorflow/core/profiler/utils/errors.h"
 #include "tensorflow/core/profiler/utils/event_span.h"
+#include "tensorflow/core/profiler/utils/html_utils.h"
 #include "tensorflow/core/profiler/utils/math_utils.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
 #include "tensorflow/core/profiler/utils/time_utils.h"
@@ -103,7 +104,7 @@
     avg = sdv = min = max = 0.0;
   } else {
     avg = sample_stats.avg();
-    sdv = std::sqrt(sample_stats.sample_variance());
+    sdv = sqrt(sample_stats.sample_variance());
     min = sample_stats.min();
     max = sample_stats.max();
   }
@@ -243,7 +244,7 @@
   kPreprocessing      // data preprocessing.
 };
 
-string InputOpCategoryString(InputOpCategory category) {
+std::string InputOpCategoryString(InputOpCategory category) {
   switch (category) {
     case InputOpCategory::kEnqueue:
       return "Enqueue";
@@ -327,10 +328,6 @@
   return details;
 }
 
-string AnchorElement(absl::string_view url, absl::string_view text) {
-  return absl::StrCat("<a href=\"", url, "\" target=\"_blank\">", text, "</a>");
-}
-
 // Returns the ratio of the host-to-device time in each step to the step-time.
 double RatioOfHostToDeviceTimeToStepTime(
     const OpMetricsDb& host_tf_metrics_db,
@@ -362,9 +359,9 @@
 }
 
 void KernelLaunchAnalysis(bool tfdata_used, double kernel_launch_percent,
-                          string* kernel_launch_classification,
-                          string* kernel_launch_statement) {
-  string percent_str = absl::StrFormat("%.1lf", kernel_launch_percent);
+                          std::string* kernel_launch_classification,
+                          std::string* kernel_launch_statement) {
+  std::string percent_str = absl::StrFormat("%.1lf", kernel_launch_percent);
   if (kernel_launch_percent >= kHighlyKernelLaunchBoundThresholdInPercent) {
     *kernel_launch_classification = "high";
     *kernel_launch_statement = absl::StrCat(
@@ -389,14 +386,14 @@
 }
 
 void AllOtherAnalysis(bool all_other_reported, double all_other_percent,
-                      string* all_other_classification,
-                      string* all_other_statement) {
+                      std::string* all_other_classification,
+                      std::string* all_other_statement) {
   if (all_other_reported) {
     *all_other_classification = "no";
     *all_other_statement = "";
     return;
   }
-  string percent_str = absl::StrFormat("%.1lf", all_other_percent);
+  std::string percent_str = absl::StrFormat("%.1lf", all_other_percent);
   if (all_other_percent >= kHighlyAllOtherBoundThresholdInPercent) {
     *all_other_classification = "high";
     *all_other_statement =
@@ -588,9 +585,10 @@
 }
 
 bool InputAnalysis(double input_percent, double all_other_percent,
-                   string* input_classification, string* input_statement) {
+                   std::string* input_classification,
+                   std::string* input_statement) {
   absl::string_view non_input_time = "other time";
-  string infeed_percent_str = absl::StrFormat("%.1lf", input_percent);
+  std::string infeed_percent_str = absl::StrFormat("%.1lf", input_percent);
   if (input_percent >= kHighlyInfeedBoundThresholdInPercent) {
     *input_classification = "host";
     *input_statement = absl::StrCat(
@@ -610,7 +608,8 @@
     // Input analysis says it is not input-bound, but "All-Other" time
     // is significant. It could still be input-bound (or Python overhead).
     *input_classification = "both";
-    string all_other_percent_str = absl::StrFormat("%.1lf", all_other_percent);
+    std::string all_other_percent_str =
+        absl::StrFormat("%.1lf", all_other_percent);
     *input_statement = absl::StrCat(
         "Your program is POTENTIALLY input-bound because ",
         all_other_percent_str,
@@ -630,8 +629,8 @@
   }
 }
 
-void OutputAnalysis(double output_percent, string* output_classification,
-                    string* output_statement) {
+void OutputAnalysis(double output_percent, std::string* output_classification,
+                    std::string* output_statement) {
   string tc_outfeed_percent_str = absl::StrFormat("%.1lf", output_percent);
   if (output_percent >= kHighlyOutfeedBoundThresholdInPercent) {
     *output_classification = "host";
@@ -703,19 +702,19 @@
   double kernel_launch_percent =
       100.0 * total_host_prepare_ms / total_step_time_ms;
   double all_other_percent = 100.0 * total_unknown_ms / total_step_time_ms;
-  string input_classification;
-  string input_statement;
+  std::string input_classification;
+  std::string input_statement;
   bool all_other_reported =
       InputAnalysis(input_percent, all_other_percent, &input_classification,
                     &input_statement);
 
-  string kernel_launch_classification;
-  string kernel_launch_statement;
+  std::string kernel_launch_classification;
+  std::string kernel_launch_statement;
   KernelLaunchAnalysis(TfDataInUse(input_time_breakdown), kernel_launch_percent,
                        &kernel_launch_classification, &kernel_launch_statement);
 
-  string all_other_classification;
-  string all_other_statement;
+  std::string all_other_classification;
+  std::string all_other_statement;
   AllOtherAnalysis(all_other_reported, all_other_percent,
                    &all_other_classification, &all_other_statement);
 
@@ -729,9 +728,9 @@
   return analysis;
 }
 
-string GetSummaryNextStep(absl::string_view input_classification,
-                          const InputTimeBreakdown& breakdown) {
-  string summary_next_step;
+std::string GetSummaryNextStep(absl::string_view input_classification,
+                               const InputTimeBreakdown& breakdown) {
+  std::string summary_next_step;
   if (input_classification == "host" || input_classification == "both") {
     if (!TfDataInUse(breakdown)) {
       summary_next_step = absl::StrCat(
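
The analysis helpers in this file (InputAnalysis, KernelLaunchAnalysis, AllOtherAnalysis) share one pattern: format the percentage once, then bucket it by threshold into a classification string plus a human-readable statement written through std::string* out-parameters. A self-contained sketch of that pattern; the threshold value and message wording here are placeholders, not the real constants and strings:

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"

// Sketch of the threshold-bucketing pattern used by the analysis helpers.
void ClassifyInput(double input_percent, std::string* classification,
                   std::string* statement) {
  const double kHighThresholdInPercent = 20.0;  // placeholder value
  std::string percent_str = absl::StrFormat("%.1lf", input_percent);
  if (input_percent >= kHighThresholdInPercent) {
    *classification = "host";
    *statement = absl::StrCat("Your program is HIGHLY input-bound (",
                              percent_str, "% of step time).");
  } else {
    *classification = "device";
    *statement =
        absl::StrCat("Only ", percent_str, "% of step time is input-bound.");
  }
}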
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
index 738daea..93b4df0 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
+++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
@@ -16,12 +16,15 @@
 #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
 #define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
 
+#include <string>
+
 #include "google/protobuf/any.pb.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
 #include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
 #include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
 
@@ -50,16 +53,18 @@
 // Returns true if an explanation for "All Others" time is also included in
 // input_statement.
 bool InputAnalysis(double input_percent, double all_other_percent,
-                   string* input_classification, string* input_statement);
+                   std::string* input_classification,
+                   std::string* input_statement);
 
-void OutputAnalysis(double output_percent, string* output_classification,
-                    string* output_statement);
+void OutputAnalysis(double output_percent, std::string* output_classification,
+                    std::string* output_statement);
 
 string GetSummaryNextStep(absl::string_view input_classification,
                           const InputTimeBreakdown& breakdown);
 
 void AddErrorMessages(const OpStats& op_stats,
                       InputPipelineAnalysisResult* result);
+
 }  // namespace profiler
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
index e19690a..bec92e0 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
@@ -15,13 +15,11 @@
 
 #include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h"
 
-#include <algorithm>
-#include <utility>
+#include <string>
 
 #include "google/protobuf/any.pb.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
 #include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
@@ -30,7 +28,10 @@
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
 #include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+#include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
 #include "tensorflow/core/profiler/utils/errors.h"
+#include "tensorflow/core/profiler/utils/html_utils.h"
 #include "tensorflow/core/profiler/utils/math_utils.h"
 #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
 #include "tensorflow/core/profiler/utils/time_utils.h"
@@ -44,24 +45,23 @@
 // statement of suggestion will be made.
 constexpr double kLowPrecisionPercentThreshold = 10;
 
-OverviewPageTip MakeOverviewPageTip(const string& text) {
-  OverviewPageTip tip;
-  tip.set_link(text);
-  return tip;
-}
+struct TfFunctionInfo {
+  absl::string_view function_name;
+  double expensive_call_percent;
+};
 
-string AnchorElement(const string& url, const string& text) {
-  return absl::StrCat("<a href=\"", url, "\" target=\"_blank\">", text, "</a>");
+OverviewPageTip MakeOverviewPageTip(std::string text) {
+  OverviewPageTip tip;
+  tip.set_link(std::move(text));
+  return tip;
 }
 
 // Makes a recommendation for looking up a document.
 // doc_url is expected to already be escaped suitably for use in an HTML
 // attribute.
-OverviewPageTip MakeOverviewPageTipDocLink(const string& doc_url,
-                                           const string& text) {
-  OverviewPageTip tip;
-  tip.set_link(AnchorElement(doc_url, text));
-  return tip;
+OverviewPageTip MakeOverviewPageTipDocLink(absl::string_view doc_url,
+                                           absl::string_view text) {
+  return MakeOverviewPageTip(AnchorElement(doc_url, text));
 }
 
 void ComputeHostTips(OverviewPageRecommendation* re) {
@@ -75,12 +75,13 @@
 
 void ComputeDeviceTips(HardwareType hardware_type,
                        OverviewPageRecommendation* re) {
-  const string& device_name = HardwareType_Name(hardware_type);
-  string timeline_name =
-      (hardware_type == tensorflow::profiler::TPU) ? "TPU core" : device_name;
-  string op_stats_toolname = (hardware_type == tensorflow::profiler::TPU)
-                                 ? "op_profile"
-                                 : "tensorflow_stats";
+  absl::string_view device_name = HardwareType_Name(hardware_type);
+  absl::string_view timeline_name = device_name;
+  absl::string_view op_stats_toolname = "tensorflow_stats";
+  if (hardware_type == tensorflow::profiler::TPU) {
+    timeline_name = "TPU core";
+    op_stats_toolname = "op_profile";
+  }
   *re->add_device_tips() = MakeOverviewPageTip(
       absl::StrCat(op_stats_toolname,
                    " (identify the time-consuming operations "
@@ -121,14 +122,16 @@
 
 }  // namespace
 
-void SetCommonRecommendation(const string& input_classification,
-                             const string& input_statement,
-                             const string& output_statement,
+void SetCommonRecommendation(absl::string_view input_classification,
+                             absl::string_view input_statement,
+                             absl::string_view output_statement,
                              HardwareType hardware_type,
+                             absl::string_view tf_function_statement_html,
                              OverviewPageRecommendation* re) {
-  re->set_bottleneck(input_classification);
-  re->set_statement(input_statement);
-  re->set_output_statement(output_statement);
+  re->set_bottleneck(std::string(input_classification));
+  re->set_statement(std::string(input_statement));
+  re->set_output_statement(std::string(output_statement));
+  re->set_tf_function_statement_html(std::string(tf_function_statement_html));
   ComputeHostTips(re);
   ComputeDeviceTips(hardware_type, re);
   ComputeDocumentationTips(re);
@@ -245,6 +248,35 @@
   return re;
 }
 
+std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db) {
+  std::vector<TfFunctionInfo> candidates;
+  for (const auto& name_fun : tf_function_db.tf_functions()) {
+    const auto& fun = name_fun.second;
+    if (fun.expensive_call_percent() >= kTfFunctionReportThresholdInPercent) {
+      candidates.push_back({name_fun.first, fun.expensive_call_percent()});
+    }
+  }
+  if (candidates.empty()) return "";
+  auto cmp = [](const TfFunctionInfo& a, const TfFunctionInfo& b) {
+    return a.expensive_call_percent > b.expensive_call_percent;
+  };
+  // Sorts candidates in descending order of expensive_call_percent.
+  absl::c_sort(candidates, cmp);
+  std::string expensive_functions = "";
+  auto num_functions_shown = std::min(
+      static_cast<decltype(candidates)::size_type>(3), candidates.size());
+
+  for (auto i = 0; i < num_functions_shown; i++) {
+    if (i > 0) absl::StrAppend(&expensive_functions, ", ");
+    absl::StrAppend(&expensive_functions, "\"", candidates[i].function_name,
+                    "\"");
+  }
+  if (candidates.size() > num_functions_shown)
+    absl::StrAppend(&expensive_functions, " and more");
+  return absl::StrCat("Expensive tf-functions detected (", expensive_functions,
+                      ") due to either retracing or eager execution.");
+}
+
 OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats,
                                           HardwareType hardware_type) {
   OverviewPage overview_page;
@@ -258,9 +290,10 @@
       overview_page.input_analysis().step_details());
   *overview_page.mutable_recommendation() = ComputeGenericRecommendation(
       bottleneck, op_stats.device_op_metrics_db().precision_stats());
-  SetCommonRecommendation(bottleneck.input_classification(),
-                          bottleneck.input_statement(), "", hardware_type,
-                          overview_page.mutable_recommendation());
+  SetCommonRecommendation(
+      bottleneck.input_classification(), bottleneck.input_statement(), "",
+      hardware_type, TfFunctionRecommendationHtml(op_stats.tf_function_db()),
+      overview_page.mutable_recommendation());
   return overview_page;
 }
 
diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h b/tensorflow/core/profiler/convert/op_stats_to_overview_page.h
index e6d1270..b4b3991 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h
+++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.h
@@ -17,9 +17,7 @@
 #define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_
 
 #include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
 #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
 #include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
@@ -29,10 +27,16 @@
 namespace tensorflow {
 namespace profiler {
 
-void SetCommonRecommendation(const string& input_classification,
-                             const string& input_statement,
-                             const string& output_statement,
+// Reports a tf-function optimization opportunity in the Overview Page if the
+// expensive-call-time percentage is over this threshold for at least one of
+// the tf-functions profiled.
+const double kTfFunctionReportThresholdInPercent = 20;
+
+void SetCommonRecommendation(absl::string_view input_classification,
+                             absl::string_view input_statement,
+                             absl::string_view output_statement,
                              HardwareType hardware_type,
+                             absl::string_view tf_function_statement_html,
                              OverviewPageRecommendation* re);
 
 OverviewPageRecommendation ComputeGenericRecommendation(
@@ -47,6 +51,9 @@
 OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats,
                                           HardwareType hardware_type);
 
+// Returns an HTML string that provides the tf-function related recommendation.
+std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db);
+
 void SetRemarks(const OpStats& op_stats, OverviewPageAnalysis* analysis);
 
 }  // namespace profiler
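
TfFunctionRecommendationHtml, declared above and defined in the .cc hunk earlier, reports at most three offenders above the threshold. A self-contained sketch of that selection and formatting on plain structs (the TfFunctionDb proto accessors are elided; the final message text is taken from the patch):

#include <algorithm>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"

struct FunctionInfo {
  std::string name;
  double expensive_call_percent;
};

// Sorts candidates by descending expensive_call_percent and lists at most
// three names, mirroring the logic in TfFunctionRecommendationHtml.
std::string ExpensiveFunctionsMessage(std::vector<FunctionInfo> candidates) {
  if (candidates.empty()) return "";
  std::sort(candidates.begin(), candidates.end(),
            [](const FunctionInfo& a, const FunctionInfo& b) {
              return a.expensive_call_percent > b.expensive_call_percent;
            });
  const size_t num_shown = std::min<size_t>(3, candidates.size());
  std::string names;
  for (size_t i = 0; i < num_shown; ++i) {
    if (i > 0) absl::StrAppend(&names, ", ");
    absl::StrAppend(&names, "\"", candidates[i].name, "\"");
  }
  if (candidates.size() > num_shown) absl::StrAppend(&names, " and more");
  return absl::StrCat("Expensive tf-functions detected (", names,
                      ") due to either retracing or eager execution.");
}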
diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
index da409f8..e23813a 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
@@ -15,13 +15,12 @@
 
 #include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h"
 
-#include "absl/container/flat_hash_set.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
 #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
 #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
-#include "tensorflow/core/profiler/utils/tf_op_utils.h"
 #include "tensorflow/core/profiler/utils/time_utils.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
index 3e098da..9ca83b5 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
@@ -15,10 +15,14 @@
 
 #include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h"
 
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/convert/xplane_to_op_stats.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
+#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/time_utils.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
@@ -75,8 +79,8 @@
                        kKernel3DurationNs, /*on_device=*/true, kKernel3,
                        &device_plane, &stream2);
 
-  const OpStats& op_stats = ConvertXSpaceToOpStats(space);
-  const TfStatsDatabase& tf_stats = ConvertOpStatsToTfStats(op_stats);
+  const OpStats op_stats = ConvertXSpaceToOpStats(space);
+  const TfStatsDatabase tf_stats = ConvertOpStatsToTfStats(op_stats);
 
   // TfOp1, TfOp2, Idle
   EXPECT_EQ(3, tf_stats.with_idle().tf_stats_record_size());
diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc
index ed0d83a..e4713cd 100644
--- a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc
+++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc
@@ -15,10 +15,18 @@
 #include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
 
 #include <sstream>
+#include <utility>
+#include <vector>
 
 #include "google/protobuf/any.pb.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+#include "tensorflow/core/profiler/utils/event_span.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.h b/tensorflow/core/profiler/convert/step_events_to_steps_db.h
index b3ea74e..9db6516 100644
--- a/tensorflow/core/profiler/convert/step_events_to_steps_db.h
+++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_
 #define TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_
 
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
 #include "tensorflow/core/profiler/utils/event_span.h"
 
diff --git a/tensorflow/core/profiler/convert/trace_events_to_json.cc b/tensorflow/core/profiler/convert/trace_events_to_json.cc
index 9c8176c..07e32ce 100644
--- a/tensorflow/core/profiler/convert/trace_events_to_json.cc
+++ b/tensorflow/core/profiler/convert/trace_events_to_json.cc
@@ -15,9 +15,14 @@
 
 #include "tensorflow/core/profiler/convert/trace_events_to_json.h"
 
+#include <algorithm>
+#include <map>
+#include <utility>
+
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "include/json/json.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc
index 785902e..023d6a7 100644
--- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc
@@ -15,16 +15,20 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h"
 
+#include <functional>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
-#include "tensorflow/core/profiler/utils/event_span.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
 #include "tensorflow/core/profiler/utils/trace_utils.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h
index 04bd0e8..9c7fca2 100644
--- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h
+++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h
@@ -17,9 +17,7 @@
 #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_
 
 #include <functional>
-#include <vector>
 
-#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
@@ -31,6 +29,7 @@
     const XPlane& device_trace,
     const std::function<void(const XEventVisitor&, KernelReport*)>&
         on_kernel_fn);
+
 }  // namespace profiler
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
index 1695bd3..5b2a748 100644
--- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
@@ -15,21 +15,28 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h"
 
-#include <cstddef>
+#include <algorithm>
 #include <string>
 #include <tuple>
+#include <type_traits>
 #include <utility>
+#include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc
index 1173e4d..e0d87ac 100644
--- a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc
@@ -15,7 +15,11 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h"
 
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
index 09df59e..4a369b8 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
@@ -15,21 +15,31 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h"
 
+#include <algorithm>
+#include <memory>
 #include <vector>
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
 #include "tensorflow/core/profiler/convert/op_stack.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/cost_utils.h"
+#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
 #include "tensorflow/core/profiler/utils/op_utils.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
+#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
 #include "tensorflow/core/profiler/utils/timespan.h"
 #include "tensorflow/core/profiler/utils/trace_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
index 1a785d0..f2d7fc7 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
@@ -21,10 +21,9 @@
 #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
-#include "tensorflow/core/profiler/utils/event_span.h"
 #include "tensorflow/core/profiler/utils/op_utils.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
-#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
index 3e577d0..8bd0443 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
@@ -15,9 +15,12 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h"
 
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
 #include "tensorflow/core/profiler/utils/time_utils.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
index 7fdd6ff..f008219 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
@@ -15,7 +15,11 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_op_stats.h"
 
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
 #include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
 #include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h"
 #include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h"
@@ -23,12 +27,19 @@
 #include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
 #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
 #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
 #include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/event_span.h"
 #include "tensorflow/core/profiler/utils/hardware_type_utils.h"
 #include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
+#include "tensorflow/core/profiler/utils/tf_op_utils.h"
+#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
index c7b140b..7b4652f 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
@@ -15,10 +15,14 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_op_stats.h"
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
 #include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+#include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/group_events.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response.cc b/tensorflow/core/profiler/convert/xplane_to_profile_response.cc
index 74dd343..e6fe749 100644
--- a/tensorflow/core/profiler/convert/xplane_to_profile_response.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_profile_response.cc
@@ -15,8 +15,10 @@
 #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
 
 #include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/human_readable_json.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
 #include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h"
@@ -33,6 +35,7 @@
 #include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
 #include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
 #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/rpc/client/save_profile.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
@@ -65,20 +68,26 @@
 }
 
 template <typename Proto>
-Status AddJsonToolData(absl::string_view tool_name, const Proto& tool_output,
-                       ProfileResponse* response) {
-  std::string json_output;
-  TF_RETURN_IF_ERROR(ProtoToHumanReadableJson(tool_output, &json_output,
-                                              /*ignore_accuracy_loss=*/true));
-  auto* tool_data = response->add_tool_data();
-  tool_data->set_name(string(tool_name));
-  tool_data->mutable_data()->append(json_output.data(), json_output.size());
+Status ConvertProtoToJson(const Proto& proto_output, std::string* json_output) {
+  protobuf::util::JsonPrintOptions json_options;
+  json_options.always_print_primitive_fields = true;
+  auto status = protobuf::util::MessageToJsonString(proto_output, json_output,
+                                                    json_options);
+  if (!status.ok()) {
+    // Convert error_msg from google::protobuf::StringPiece (or
+    // absl::string_view) to tensorflow::StringPiece.
+    auto error_msg = status.message();
+    return errors::Internal(
+        strings::StrCat("Could not convert proto to JSON string: ",
+                        StringPiece(error_msg.data(), error_msg.length())));
+  }
   return Status::OK();
 }
 
 // Returns the tool name with extension.
 string ToolName(absl::string_view tool) {
   if (tool == kTraceViewer) return "trace.json.gz";
+  if (tool == kMemoryProfile) return "memory_profile.json.gz";
   return absl::StrCat(tool, ".pb");
 }
 
@@ -130,8 +139,11 @@
   if (tools.contains(kMemoryProfile)) {
     if (const XPlane* host_plane = FindPlaneWithName(xspace, kHostThreads)) {
       MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane);
-      TF_RETURN_IF_ERROR(
-          AddJsonToolData(ToolName(kMemoryProfile), memory_profile, response));
+      std::string json_output;
+      TF_RETURN_IF_ERROR(ConvertProtoToJson(memory_profile, &json_output));
+      TF_RETURN_IF_ERROR(SaveGzippedToolDataToTensorboardProfile(
+          req.repository_root(), req.session_id(), req.host_name(),
+          ToolName(kMemoryProfile), json_output));
     }
   }
   return Status::OK();
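
The ConvertProtoToJson helper introduced above replaces ProtoToHumanReadableJson with protobuf's public JSON printer. A minimal sketch of the same conversion, with the TF Status plumbing replaced by a bool so it stands alone:

#include <string>

#include "google/protobuf/message.h"
#include "google/protobuf/util/json_util.h"

// Serializes any proto message to JSON, printing primitive fields even when
// they hold default values (matching the JsonPrintOptions set in the patch).
bool ProtoToJson(const google::protobuf::Message& proto,
                 std::string* json_output) {
  google::protobuf::util::JsonPrintOptions options;
  options.always_print_primitive_fields = true;
  return google::protobuf::util::MessageToJsonString(proto, json_output,
                                                     options)
      .ok();
}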
diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response.h b/tensorflow/core/profiler/convert/xplane_to_profile_response.h
index 84b9fdd..03ca13f 100644
--- a/tensorflow/core/profiler/convert/xplane_to_profile_response.h
+++ b/tensorflow/core/profiler/convert/xplane_to_profile_response.h
@@ -15,8 +15,6 @@
 #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_PROFILE_RESPONSE_H_
 #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_PROFILE_RESPONSE_H_
 
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/profiler/profiler_service.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc b/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc
index d4965a9..ad9ca10 100644
--- a/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc
@@ -14,13 +14,14 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
 
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/profiler/profiler_service.pb.h"
 #include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
 #include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
 #include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc
index c7dcd62..7bb7cd6 100644
--- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc
@@ -15,10 +15,18 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_step_events.h"
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/event_span.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
 #include "tensorflow/core/profiler/utils/trace_utils.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.h b/tensorflow/core/profiler/convert/xplane_to_step_events.h
index a7ac3b9..62fc898 100644
--- a/tensorflow/core/profiler/convert/xplane_to_step_events.h
+++ b/tensorflow/core/profiler/convert/xplane_to_step_events.h
@@ -18,7 +18,7 @@
 
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/event_span.h"
-#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc
index 3e1610c..36e6a2c 100644
--- a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc
@@ -15,7 +15,13 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_step_events.h"
 
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/event_span.h"
 #include "tensorflow/core/profiler/utils/group_events.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc
index f768d3b..b25cdc4 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc
@@ -15,20 +15,25 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
 
+#include <algorithm>
 #include <stack>
+#include <string>
+#include <utility>
+#include <vector>
 
 #include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/math_utils.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
 #include "tensorflow/core/profiler/utils/timespan.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/utils/xplane_utils.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
@@ -54,6 +59,21 @@
   DCHECK(false);
 }
 
+double ComputeExpensiveCallPercent(const TfFunction& tf_function) {
+  // Computes expensiveness in terms of time spent (rather than call count).
+  uint64 total_call_time_ps = 0;
+  uint64 expensive_call_time_ps = 0;
+  for (const auto& mode_metrics : tf_function.metrics()) {
+    const auto mode = mode_metrics.first;
+    const auto& metrics = mode_metrics.second;
+    total_call_time_ps += metrics.self_time_ps();
+    if (mode == TRACED_MODE || mode == EAGER_MODE) {
+      expensive_call_time_ps += metrics.self_time_ps();
+    }
+  }
+  return SafeDivide(100.0 * expensive_call_time_ps, total_call_time_ps);
+}
+
 // Each invocation of a tf-function creates an ActivationRecord.
 struct ActivationRecord {
   std::string function_name;               // name of the tf-function.
@@ -133,6 +153,7 @@
       CombineTfFunctionMetrics(src_metrics, dst_metrics);
     }
   }
+  dst->set_expensive_call_percent(ComputeExpensiveCallPercent(*dst));
 }
 
 // Execution history of all tf-functions invoked.
@@ -210,6 +231,10 @@
       metrics->set_count(metrics->count() + 1);
       metrics->set_self_time_ps(metrics->self_time_ps() + self_time_ps);
     }
+    for (auto& name_fun : *result.mutable_tf_functions()) {
+      TfFunction& fun = name_fun.second;
+      fun.set_expensive_call_percent(ComputeExpensiveCallPercent(fun));
+    }
     return result;
   }
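
For intuition about the metric added above, here is a minimal, self-contained sketch of the same arithmetic, using plain C++ containers in place of the TfFunction proto and a SafeDivide that is assumed to mirror the zero-guarded helper from math_utils.h:

#include <cstdint>
#include <iostream>
#include <map>

enum Mode { TRACED_MODE, EAGER_MODE, NOT_TRACED_MODE, CONCRETE_MODE };

// Assumed to behave like the math_utils.h helper: returns 0 for a zero divisor.
double SafeDivide(double dividend, double divisor) {
  return divisor == 0.0 ? 0.0 : dividend / divisor;
}

double ComputeExpensiveCallPercent(
    const std::map<Mode, uint64_t>& self_time_ps) {
  uint64_t total_ps = 0, expensive_ps = 0;
  for (const auto& entry : self_time_ps) {
    total_ps += entry.second;
    // Traced and eager calls count as expensive, as in the function above.
    if (entry.first == TRACED_MODE || entry.first == EAGER_MODE) {
      expensive_ps += entry.second;
    }
  }
  return SafeDivide(100.0 * expensive_ps, total_ps);
}

int main() {
  // Example: 90 ps traced + 10 ps not-traced => 90%, the value the
  // mixed-mode test below asserts with EXPECT_NEAR(..., 90, kMaxError).
  std::cout << ComputeExpensiveCallPercent(
                   {{TRACED_MODE, 90}, {NOT_TRACED_MODE, 10}})
            << std::endl;  // prints 90
}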
 
diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h b/tensorflow/core/profiler/convert/xplane_to_tf_functions.h
index 470b22d..df55ac7 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h
+++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions.h
@@ -16,8 +16,9 @@
 #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_
 #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_
 
+#include <string>
+
 #include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc
index 253ef1a..25e56d1 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc
@@ -15,12 +15,17 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
 
+#include <string>
+
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
@@ -33,6 +38,8 @@
 const absl::string_view kNotTracedNonXla = "notTraced-nonXla";
 const absl::string_view kNotTracedXla = "notTraced-xla";
 
+constexpr double kMaxError = 0.001;
+
 TfFunctionDb ConvertXSpaceToTfFunctionDb(const XSpace& space) {
   TfFunctionDb result;
   const XPlane* host_plane = FindPlaneWithName(space, kHostThreads);
@@ -75,6 +82,8 @@
       tf_function_db.tf_functions().at(kFunctionName);
   EXPECT_EQ(tf_function.total_tracing_count(), 4);
   EXPECT_EQ(tf_function.compiler(), MIXED_COMPILER);
+  EXPECT_NEAR(tf_function.expensive_call_percent(), 90, kMaxError);
+
   const auto& metrics = tf_function.metrics();
   EXPECT_EQ(metrics.size(), 2);
   EXPECT_EQ(metrics.count(TRACED_MODE), 1);
@@ -108,6 +117,7 @@
       tf_function_db.tf_functions().at(kOuterFunctionName);
   EXPECT_EQ(outer.total_tracing_count(), 1);
   EXPECT_EQ(outer.compiler(), OTHER_COMPILER);
+  EXPECT_NEAR(outer.expensive_call_percent(), 100, kMaxError);
   const auto& outer_metrics = outer.metrics();
   EXPECT_EQ(outer_metrics.size(), 1);
   EXPECT_EQ(outer_metrics.count(TRACED_MODE), 1);
@@ -118,6 +128,7 @@
       tf_function_db.tf_functions().at(kInnerFunctionName);
   EXPECT_EQ(inner.total_tracing_count(), 0);
   EXPECT_EQ(inner.compiler(), XLA_COMPILER);
+  EXPECT_NEAR(inner.expensive_call_percent(), 0, kMaxError);
   const auto& inner_metrics = inner.metrics();
   EXPECT_EQ(inner_metrics.size(), 1);
   EXPECT_EQ(inner_metrics.count(NOT_TRACED_MODE), 1);
@@ -148,6 +159,7 @@
       tf_function_db.tf_functions().at(kEagerFunctionName);
   EXPECT_EQ(eager.total_tracing_count(), 0);
   EXPECT_EQ(eager.compiler(), INVALID_COMPILER);
+  EXPECT_NEAR(eager.expensive_call_percent(), 100, kMaxError);
   const auto& eager_metrics = eager.metrics();
   EXPECT_EQ(eager_metrics.size(), 1);
   EXPECT_EQ(eager_metrics.count(EAGER_MODE), 1);
@@ -158,6 +170,7 @@
       tf_function_db.tf_functions().at(kConcreteFunctionName);
   EXPECT_EQ(concrete.total_tracing_count(), 0);
   EXPECT_EQ(concrete.compiler(), INVALID_COMPILER);
+  EXPECT_NEAR(concrete.expensive_call_percent(), 0, kMaxError);
   const auto& concrete_metrics = concrete.metrics();
   EXPECT_EQ(concrete_metrics.size(), 1);
   EXPECT_EQ(concrete_metrics.count(CONCRETE_MODE), 1);
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
index 901f3be..c404f7b 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
@@ -15,8 +15,21 @@
 
 #include "tensorflow/core/profiler/convert/xplane_to_trace_events.h"
 
+#include <stddef.h>
+
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.h b/tensorflow/core/profiler/convert/xplane_to_trace_events.h
index 5c6fbea..b7bddb7 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events.h
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.h
@@ -16,7 +16,8 @@
 #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_EVENTS_H_
 #define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_EVENTS_H_
 
-#include "absl/strings/str_split.h"
+#include <string>
+
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
index afff5e6..b9a9fe0 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
@@ -16,8 +16,9 @@
 #include "tensorflow/core/profiler/convert/xplane_to_trace_events.h"
 
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD
index 9fab42c..85fa4e7 100644
--- a/tensorflow/core/profiler/internal/BUILD
+++ b/tensorflow/core/profiler/internal/BUILD
@@ -423,8 +423,10 @@
     deps = [
         ":traceme_recorder",
         "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
         "@com_google_absl//absl/strings",
-        "@com_google_googletest//:gtest_main",
+        "@com_google_googletest//:gtest",
     ],
 )
 
@@ -434,7 +436,6 @@
     deps = [
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/profiler:profiler_options_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
     ],
 )
@@ -444,6 +445,7 @@
     hdrs = ["profiler_factory.h"],
     deps = [
         ":profiler_interface",
+        "//tensorflow/core/profiler:profiler_options_proto_cc",
     ] + if_static([
         ":profiler_factory_impl",
     ]),
@@ -461,8 +463,7 @@
     deps = [
         ":profiler_interface",
         "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "//tensorflow/core/profiler:profiler_options_proto_cc",
     ],
     alwayslink = True,
 )
@@ -513,15 +514,10 @@
     srcs = ["scoped_annotation_test.cc"],
     deps = [
         ":annotation_stack",
-        "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
         "//tensorflow/core/profiler/lib:scoped_annotation",
         "@com_google_absl//absl/strings",
     ],
@@ -544,6 +540,6 @@
         ":parse_annotation",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
+        "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/core/profiler/internal/annotation_stack.cc b/tensorflow/core/profiler/internal/annotation_stack.cc
index 4cfd102..4c15ca4 100644
--- a/tensorflow/core/profiler/internal/annotation_stack.cc
+++ b/tensorflow/core/profiler/internal/annotation_stack.cc
@@ -15,6 +15,10 @@
 
 #include "tensorflow/core/profiler/internal/annotation_stack.h"
 
+#include <atomic>
+
+#include "tensorflow/core/platform/types.h"
+
 namespace tensorflow {
 namespace profiler {
 namespace internal {
diff --git a/tensorflow/core/profiler/internal/annotation_stack.h b/tensorflow/core/profiler/internal/annotation_stack.h
index 38cd962..e626c4c 100644
--- a/tensorflow/core/profiler/internal/annotation_stack.h
+++ b/tensorflow/core/profiler/internal/annotation_stack.h
@@ -18,6 +18,7 @@
 #include <stddef.h>
 
 #include <atomic>
+#include <utility>
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD
index e156667..c24c8c7 100644
--- a/tensorflow/core/profiler/internal/cpu/BUILD
+++ b/tensorflow/core/profiler/internal/cpu/BUILD
@@ -18,6 +18,7 @@
         "//tensorflow/core/profiler/utils:tf_op_utils",
         "//tensorflow/core/profiler/utils:xplane_builder",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -26,10 +27,10 @@
     srcs = ["host_tracer.cc"],
     deps = [
         ":host_tracer_utils",
-        "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/profiler:profiler_options_proto_cc",
         "//tensorflow/core/profiler/internal:profiler_factory",
         "//tensorflow/core/profiler/internal:profiler_interface",
         "//tensorflow/core/profiler/internal:traceme_recorder",
@@ -50,14 +51,17 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/profiler:profiler_options_proto_cc",
         "//tensorflow/core/profiler/internal:profiler_interface",
         "//tensorflow/core/profiler/lib:profiler_session",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "//tensorflow/core/profiler/utils:xplane_schema",
         "//tensorflow/core/profiler/utils:xplane_visitor",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
-        "@com_google_googletest//:gtest_main",
+        "@com_google_googletest//:gtest",
     ],
 )
 
@@ -67,17 +71,14 @@
     copts = ["-fexceptions"],
     features = ["-use_header_modules"],
     deps = [
-        "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/profiler:profiler_options_proto_cc",
         "//tensorflow/core/profiler/internal:profiler_factory",
         "//tensorflow/core/profiler/internal:profiler_interface",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
-        "//tensorflow/core/profiler/utils:xplane_schema",
-        "//tensorflow/core/profiler/utils:xplane_utils",
         "//tensorflow/python/profiler/internal:python_hooks",
-        "@com_google_absl//absl/strings",
     ],
     alwayslink = True,
 )
@@ -86,9 +87,12 @@
     name = "metadata_collector",
     srcs = ["metadata_collector.cc"],
     deps = [
+        "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xla/service/gpu:gpu_debug_info_manager",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/profiler:profiler_options_proto_cc",
         "//tensorflow/core/profiler/internal:profiler_factory",
         "//tensorflow/core/profiler/internal:profiler_interface",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc
index 30b87c8..be1a7a2 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc
@@ -12,18 +12,23 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <memory>
+#include <string>
 #include <utility>
 #include <vector>
 
 #include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/step_stats.pb.h"
 #include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/internal/cpu/host_tracer_utils.h"
 #include "tensorflow/core/profiler/internal/profiler_factory.h"
 #include "tensorflow/core/profiler/internal/profiler_interface.h"
 #include "tensorflow/core/profiler/internal/traceme_recorder.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
@@ -119,8 +124,8 @@
           std::vector<absl::string_view> parts =
               absl::StrSplit(event.name, kUserMetadataMarker);
           if (parts.size() >= 2) {
-            ns->set_node_name(string(parts[0]));
-            ns->set_timeline_label(string(parts[1]));
+            ns->set_node_name(std::string(parts[0]));
+            ns->set_timeline_label(std::string(parts[1]));
           } else {
             ns->set_node_name(std::move(event.name));
           }
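
For context on the split just above, a hedged illustration; the actual kUserMetadataMarker value is defined earlier in host_tracer.cc and is assumed to be "#" purely for this example:

// Assuming kUserMetadataMarker == "#" (illustration only):
//   event.name         = "MyOp#step=3"
//   parts              = {"MyOp", "step=3"}   // parts.size() >= 2
//   ns->node_name      -> "MyOp"
//   ns->timeline_label -> "step=3"
// An event name without the marker yields a single part and takes the else
// branch: the whole name is moved into node_name unchanged.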
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
index e32ba92..499b7b6 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
@@ -12,17 +12,23 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <memory>
+#include <ostream>
 #include <string>
 
 #include <gmock/gmock.h>
-#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "tensorflow/core/framework/step_stats.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/internal/profiler_interface.h"
 #include "tensorflow/core/profiler/lib/profiler_session.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
@@ -38,13 +44,13 @@
 
 using ::testing::UnorderedElementsAre;
 
-NodeExecStats MakeNodeStats(const string& name, uint32 thread_id,
-                            const string& label = "") {
+NodeExecStats MakeNodeStats(absl::string_view name, uint32 thread_id,
+                            absl::string_view label = "") {
   NodeExecStats ns;
-  ns.set_node_name(name);
+  ns.set_node_name(std::string(name));
   ns.set_thread_id(thread_id);
   if (!label.empty()) {
-    ns.set_timeline_label(label);
+    ns.set_timeline_label(std::string(label));
   }
   return ns;
 }
@@ -109,7 +115,7 @@
 
 TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) {
   uint32 thread_id;
-  string thread_name = "MyThreadName";
+  std::string thread_name = "MyThreadName";
   XSpace space;
 
   // We start a thread with a known and controlled name. As of the time of
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
index a4709ae..2e5d8ac 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
@@ -14,10 +14,13 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/internal/cpu/host_tracer_utils.h"
 
+#include <string>
 #include <utility>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/internal/parse_annotation.h"
 #include "tensorflow/core/profiler/internal/traceme_recorder.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc
index c6aa784..58da20a 100644
--- a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc
+++ b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc
@@ -13,17 +13,23 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <memory>
+#include <string>
 #include <utility>
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/profiler/internal/profiler_factory.h"
 #include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/protobuf/config.pb.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
index 103db6e..d684cb8 100644
--- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc
+++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
@@ -12,18 +12,17 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <utility>
-#include <vector>
+#include <memory>
 
-#include "absl/strings/str_split.h"
-#include "tensorflow/core/framework/step_stats.pb.h"
-#include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/profiler/internal/profiler_factory.h"
 #include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/python/profiler/internal/python_hooks.h"
 
diff --git a/tensorflow/core/profiler/internal/parse_annotation.cc b/tensorflow/core/profiler/internal/parse_annotation.cc
index 2a3fa3f..32c26be 100644
--- a/tensorflow/core/profiler/internal/parse_annotation.cc
+++ b/tensorflow/core/profiler/internal/parse_annotation.cc
@@ -15,6 +15,9 @@
 #include "tensorflow/core/profiler/internal/parse_annotation.h"
 
 #include <stack>
+#include <string>
+#include <utility>
+#include <vector>
 
 #include "absl/strings/ascii.h"
 #include "absl/strings/str_split.h"
diff --git a/tensorflow/core/profiler/internal/parse_annotation.h b/tensorflow/core/profiler/internal/parse_annotation.h
index 6c2e536..bb0f122 100644
--- a/tensorflow/core/profiler/internal/parse_annotation.h
+++ b/tensorflow/core/profiler/internal/parse_annotation.h
@@ -16,7 +16,6 @@
 #ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PARSE_ANNOTATION_H_
 #define TENSORFLOW_CORE_PROFILER_INTERNAL_PARSE_ANNOTATION_H_
 
-#include <utility>
 #include <vector>
 
 #include "absl/strings/string_view.h"
diff --git a/tensorflow/core/profiler/internal/parse_annotation_test.cc b/tensorflow/core/profiler/internal/parse_annotation_test.cc
index 4d4a2d5..e5d876a 100644
--- a/tensorflow/core/profiler/internal/parse_annotation_test.cc
+++ b/tensorflow/core/profiler/internal/parse_annotation_test.cc
@@ -14,6 +14,9 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/internal/parse_annotation.h"
 
+#include <vector>
+
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/internal/profiler_factory.cc b/tensorflow/core/profiler/internal/profiler_factory.cc
index e2bae59..5152e79 100644
--- a/tensorflow/core/profiler/internal/profiler_factory.cc
+++ b/tensorflow/core/profiler/internal/profiler_factory.cc
@@ -14,8 +14,14 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/internal/profiler_factory.h"
 
+#include <memory>
+#include <utility>
+#include <vector>
+
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/internal/profiler_factory.h b/tensorflow/core/profiler/internal/profiler_factory.h
index 6bcdcf2..c223d72 100644
--- a/tensorflow/core/profiler/internal/profiler_factory.h
+++ b/tensorflow/core/profiler/internal/profiler_factory.h
@@ -19,6 +19,7 @@
 #include <vector>
 
 #include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/internal/profiler_interface.h
index 79dfc7a..9fe85e3 100644
--- a/tensorflow/core/profiler/internal/profiler_interface.h
+++ b/tensorflow/core/profiler/internal/profiler_interface.h
@@ -16,7 +16,6 @@
 #define TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_
 
 #include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/profiler/profiler_options.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 
diff --git a/tensorflow/core/profiler/internal/scoped_annotation_test.cc b/tensorflow/core/profiler/internal/scoped_annotation_test.cc
index 70a627f..50c1244 100644
--- a/tensorflow/core/profiler/internal/scoped_annotation_test.cc
+++ b/tensorflow/core/profiler/internal/scoped_annotation_test.cc
@@ -15,10 +15,11 @@
 
 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
 
+#include <string>
+
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/internal/annotation_stack.h"
 
 namespace tensorflow {
@@ -48,11 +49,13 @@
   EXPECT_EQ(AnnotationStack::Get(), "");  // not enabled
 }
 
-string GenerateRandomString(int length) { return string(length, 'a'); }
+std::string GenerateRandomString(int length) {
+  return std::string(length, 'a');
+}
 
 void BM_ScopedAnnotationDisabled(int iters, int annotation_size) {
   testing::StopTiming();
-  string annotation = GenerateRandomString(annotation_size);
+  std::string annotation = GenerateRandomString(annotation_size);
   testing::StartTiming();
   for (int i = 0; i < iters; i++) {
     ScopedAnnotation trace(annotation);
@@ -64,7 +67,7 @@
 
 void BM_ScopedAnnotationEnabled(int iters, int annotation_size) {
   testing::StopTiming();
-  string annotation = GenerateRandomString(annotation_size);
+  std::string annotation = GenerateRandomString(annotation_size);
   AnnotationStack::Enable(true);
   testing::StartTiming();
   for (int i = 0; i < iters; i++) {
@@ -78,7 +81,7 @@
 
 void BM_ScopedAnnotationEnabled_Nested(int iters, int annotation_size) {
   testing::StopTiming();
-  string annotation = GenerateRandomString(annotation_size);
+  std::string annotation = GenerateRandomString(annotation_size);
   AnnotationStack::Enable(true);
   testing::StartTiming();
   for (int i = 0; i < iters; i++) {
diff --git a/tensorflow/core/profiler/internal/tfprof_stats.cc b/tensorflow/core/profiler/internal/tfprof_stats.cc
index 22b3bdc..56e6e2b 100644
--- a/tensorflow/core/profiler/internal/tfprof_stats.cc
+++ b/tensorflow/core/profiler/internal/tfprof_stats.cc
@@ -58,7 +58,6 @@
       ckpt_reader_(std::move(ckpt_reader)) {
   CHECK(graph) << "Must at least have GraphDef";
 
-  absl::PrintF("Parsing Inputs...\n");
   AddGraph(std::move(graph));
   if (run_meta && run_meta->has_step_stats()) {
     AddRunMeta(0, std::move(run_meta));
diff --git a/tensorflow/core/profiler/internal/traceme_recorder.cc b/tensorflow/core/profiler/internal/traceme_recorder.cc
index 365e399..268585b 100644
--- a/tensorflow/core/profiler/internal/traceme_recorder.cc
+++ b/tensorflow/core/profiler/internal/traceme_recorder.cc
@@ -16,8 +16,18 @@
 
 #include <stddef.h>
 
+#include <algorithm>
+#include <atomic>
+#include <new>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/internal/traceme_recorder.h b/tensorflow/core/profiler/internal/traceme_recorder.h
index 8b5b32c..1da7d4c 100644
--- a/tensorflow/core/profiler/internal/traceme_recorder.h
+++ b/tensorflow/core/profiler/internal/traceme_recorder.h
@@ -15,8 +15,6 @@
 #ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_
 #define TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_
 
-#include <stddef.h>
-
 #include <atomic>
 #include <vector>
 
diff --git a/tensorflow/core/profiler/internal/traceme_recorder_test.cc b/tensorflow/core/profiler/internal/traceme_recorder_test.cc
index 9047888..8d7abc9 100644
--- a/tensorflow/core/profiler/internal/traceme_recorder_test.cc
+++ b/tensorflow/core/profiler/internal/traceme_recorder_test.cc
@@ -15,19 +15,28 @@
 #include "tensorflow/core/profiler/internal/traceme_recorder.h"
 
 #include <atomic>
+#include <istream>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
 
 #include <gmock/gmock.h>
-#include <gtest/gtest.h>
 #include "absl/strings/str_cat.h"
-#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/threadpool.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace profiler {
 namespace {
 
+using ::testing::ElementsAre;
+
 MATCHER_P(Named, name, "") { return arg.name == name; }
 
 constexpr static uint64 kNanosInSec = 1000000000;
@@ -45,7 +54,7 @@
 
   ASSERT_EQ(results.size(), 1);
   EXPECT_THAT(results[0].events,
-              ::testing::ElementsAre(Named("during1"), Named("during2")));
+              ElementsAre(Named("during1"), Named("during2")));
 }
 
 void SpinNanos(int nanos) {
diff --git a/tensorflow/core/profiler/profiler_service.proto b/tensorflow/core/profiler/profiler_service.proto
index 37ca408..a096a10e 100644
--- a/tensorflow/core/profiler/profiler_service.proto
+++ b/tensorflow/core/profiler/profiler_service.proto
@@ -10,6 +10,10 @@
 service ProfilerService {
   // Starts a profiling session, blocks until it completes, and returns data.
   rpc Profile(ProfileRequest) returns (ProfileResponse) {}
+  // Signals termination of the Profile RPC for an ongoing profiling session;
+  // the Profile RPC then returns successfully and prematurely instead of
+  // timing out. This is used by programmatic mode to end the session on workers.
+  rpc Terminate(TerminateRequest) returns (TerminateResponse) {}
   // Collects profiling data and returns user-friendly metrics.
   rpc Monitor(MonitorRequest) returns (MonitorResponse) {}
 }
@@ -81,6 +85,13 @@
   // next-field: 8
 }
 
+message TerminateRequest {
+  // Which session id to terminate.
+  string session_id = 1;
+}
+
+message TerminateResponse {}
+
 message MonitorRequest {
   // Duration for which to profile between each update.
   uint64 duration_ms = 1;
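
A minimal client-side sketch of the new Terminate RPC. The generated stub namespace (tensorflow::grpc::ProfilerService) and the channel setup are assumptions based on standard gRPC C++ codegen; only the request/response messages and the method itself come from the proto above.

#include <memory>
#include <string>

#include "grpcpp/grpcpp.h"
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"

// Illustrative only: asks the profiler server at `target` to end the
// in-flight Profile RPC for `session_id` early.
::grpc::Status TerminateSession(const std::string& target,
                                const std::string& session_id) {
  auto channel =
      ::grpc::CreateChannel(target, ::grpc::InsecureChannelCredentials());
  // Stub type name is an assumption; the sync call signature follows the
  // usual generated API.
  auto stub = tensorflow::grpc::ProfilerService::NewStub(channel);
  tensorflow::TerminateRequest request;
  request.set_session_id(session_id);
  tensorflow::TerminateResponse response;
  ::grpc::ClientContext context;
  return stub->Terminate(&context, request, &response);
}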
diff --git a/tensorflow/core/profiler/protobuf/overview_page.proto b/tensorflow/core/profiler/protobuf/overview_page.proto
index 8c83dbd..018aa75 100644
--- a/tensorflow/core/profiler/protobuf/overview_page.proto
+++ b/tensorflow/core/profiler/protobuf/overview_page.proto
@@ -84,6 +84,9 @@
   // A statement for output that recommends the next steps for investigating the
   // bottleneck.
   string output_statement = 9;
+  // A statement that recommends the next steps for investigating tf-function
+  // related bottlenecks (it is HTML so that it can link to other tools/docs).
+  string tf_function_statement_html = 10;
   // A list of tips for improving host performance.
   repeated OverviewPageTip host_tips = 3;
   // A list of tips for improving device performance.
diff --git a/tensorflow/core/profiler/protobuf/tf_function.proto b/tensorflow/core/profiler/protobuf/tf_function.proto
index fe07c00..1f5e153 100644
--- a/tensorflow/core/profiler/protobuf/tf_function.proto
+++ b/tensorflow/core/profiler/protobuf/tf_function.proto
@@ -49,6 +49,9 @@
   int64 total_tracing_count = 2;
   // Compiler used to compile this function.
   TfFunctionCompiler compiler = 3;
+  // Percentage of time spent in the expensive calls to this function in the
+  // profiled period.
+  double expensive_call_percent = 4;
 }
 
 // Statistics for all tf-functions.
diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD
index b5b631f..1e572df 100644
--- a/tensorflow/core/profiler/rpc/BUILD
+++ b/tensorflow/core/profiler/rpc/BUILD
@@ -19,9 +19,7 @@
         "//tensorflow/core/profiler/convert:xplane_to_profile_response",
         "//tensorflow/core/profiler/lib:profiler_session_headers",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
-        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
         tf_grpc_cc_dependency(),
     ],
 )
diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD
index bde5708..609f98a 100644
--- a/tensorflow/core/profiler/rpc/client/BUILD
+++ b/tensorflow/core/profiler/rpc/client/BUILD
@@ -12,7 +12,9 @@
     deps = [
         ":save_profile",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/profiler:profiler_analysis_proto_cc",
+        "//tensorflow/core/profiler:profiler_options_proto_cc",
         "//tensorflow/core/profiler:profiler_service_proto_cc",
         "@com_google_absl//absl/strings",
         tf_grpc_cc_dependency(),
diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.cc b/tensorflow/core/profiler/rpc/client/capture_profile.cc
index 5335d18..a8642af 100644
--- a/tensorflow/core/profiler/rpc/client/capture_profile.cc
+++ b/tensorflow/core/profiler/rpc/client/capture_profile.cc
@@ -14,18 +14,25 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/rpc/client/capture_profile.h"
 
+#include <iostream>
+#include <limits>
+#include <memory>
 #include <vector>
 
 #include "grpcpp/grpcpp.h"
-#include "absl/strings/escaping.h"
-#include "absl/strings/match.h"
 #include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/path.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/profiler_analysis.grpc.pb.h"
+#include "tensorflow/core/profiler/profiler_analysis.pb.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
+#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
+#include "tensorflow/core/profiler/profiler_service.pb.h"
 #include "tensorflow/core/profiler/rpc/client/save_profile.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.h b/tensorflow/core/profiler/rpc/client/capture_profile.h
index 1bde73f..c809d20 100644
--- a/tensorflow/core/profiler/rpc/client/capture_profile.h
+++ b/tensorflow/core/profiler/rpc/client/capture_profile.h
@@ -19,7 +19,7 @@
 
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/rpc/client/save_profile.cc b/tensorflow/core/profiler/rpc/client/save_profile.cc
index e328bf1..9cf2e29 100644
--- a/tensorflow/core/profiler/rpc/client/save_profile.cc
+++ b/tensorflow/core/profiler/rpc/client/save_profile.cc
@@ -17,7 +17,6 @@
 
 #include <initializer_list>
 #include <memory>
-#include <ostream>
 #include <sstream>
 #include <string>
 #include <vector>
diff --git a/tensorflow/core/profiler/rpc/profiler_server.cc b/tensorflow/core/profiler/rpc/profiler_server.cc
index 4d2f3c3..f05a829 100644
--- a/tensorflow/core/profiler/rpc/profiler_server.cc
+++ b/tensorflow/core/profiler/rpc/profiler_server.cc
@@ -16,17 +16,19 @@
 #include "tensorflow/core/profiler/rpc/profiler_server.h"
 
 #include <memory>
-#include <utility>
+#include <string>
 
 #include "grpcpp/grpcpp.h"
 #include "absl/strings/str_cat.h"
-#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
 
 namespace tensorflow {
 
 void ProfilerServer::StartProfilerServer(int32 port) {
-  string server_address = absl::StrCat("0.0.0.0:", port);
+  std::string server_address = absl::StrCat("0.0.0.0:", port);
   service_ = CreateProfilerService();
   ::grpc::ServerBuilder builder;
   builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
index 555f4c3..0a234d7 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
@@ -15,18 +15,23 @@
 
 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
 
+#include <memory>
+
 #include "grpcpp/support/status.h"
-#include "absl/container/flat_hash_set.h"
+#include "absl/container/flat_hash_map.h"
 #include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
 #include "tensorflow/core/profiler/internal/profiler_interface.h"
 #include "tensorflow/core/profiler/lib/profiler_session.h"
+#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
+#include "tensorflow/core/profiler/profiler_service.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 
 namespace tensorflow {
@@ -61,11 +66,16 @@
     }
 
     Env* env = Env::Default();
-    for (size_t i = 0; i < req->duration_ms(); ++i) {
+    for (uint64 i = 0; i < req->duration_ms(); ++i) {
       env->SleepForMicroseconds(EnvTime::kMillisToMicros);
       if (ctx->IsCancelled()) {
         return ::grpc::Status::CANCELLED;
       }
+      if (TF_PREDICT_FALSE(IsStopped(req->session_id()))) {
+        mutex_lock lock(mutex_);
+        stop_signals_per_session_.erase(req->session_id());
+        break;
+      }
     }
 
     status = CollectDataToResponse(*req, profiler.get(), response);
@@ -76,6 +86,25 @@
 
     return ::grpc::Status::OK;
   }
+
+  ::grpc::Status Terminate(::grpc::ServerContext* ctx,
+                           const TerminateRequest* req,
+                           TerminateResponse* response) override {
+    mutex_lock lock(mutex_);
+    stop_signals_per_session_[req->session_id()] = true;
+    return ::grpc::Status::OK;
+  }
+
+ private:
+  bool IsStopped(const std::string& session_id) {
+    mutex_lock lock(mutex_);
+    auto it = stop_signals_per_session_.find(session_id);
+    return it != stop_signals_per_session_.end() && it->second;
+  }
+
+  mutex mutex_;
+  absl::flat_hash_map<std::string, bool> stop_signals_per_session_
+      GUARDED_BY(mutex_);
 };
 
 }  // namespace
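
Putting the server pieces together, a hedged timeline of programmatic termination (session id and times are made up):

// t=0ms     worker: Profile(session_id="s1", duration_ms=60000) begins; each
//           loop iteration sleeps ~1ms, checks ctx->IsCancelled(), then
//           polls IsStopped("s1").
// t=1234ms  coordinator: Terminate(session_id="s1") takes mutex_ and sets
//           stop_signals_per_session_["s1"] = true.
// t=1235ms  worker: the next poll observes the flag, erases the map entry,
//           breaks out of the wait loop, and returns the data collected so
//           far with ::grpc::Status::OK.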
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.h b/tensorflow/core/profiler/rpc/profiler_service_impl.h
index 4a7636c..00a850a 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.h
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.h
@@ -15,10 +15,8 @@
 #ifndef TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_
 #define TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_
 
-#include "grpcpp/grpcpp.h"
-#include "grpcpp/server_context.h"
-#include "grpcpp/support/status.h"
-#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include <memory>
+
 #include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD
index ad26dcc..ca20236 100644
--- a/tensorflow/core/profiler/utils/BUILD
+++ b/tensorflow/core/profiler/utils/BUILD
@@ -30,6 +30,7 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
     ],
@@ -51,6 +52,14 @@
 )
 
 cc_library(
+    name = "html_utils",
+    hdrs = ["html_utils.h"],
+    deps = [
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+cc_library(
     name = "op_metrics_db_utils",
     srcs = ["op_metrics_db_utils.cc"],
     hdrs = ["op_metrics_db_utils.h"],
@@ -83,7 +92,6 @@
     hdrs = ["tf_op_utils.h"],
     deps = [
         "//tensorflow/core:regexp_internal",
-        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -96,6 +104,7 @@
         ":tf_op_utils",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -156,6 +165,7 @@
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -170,7 +180,6 @@
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
-        "@com_google_absl//absl/types:span",
     ],
 )
 
@@ -196,7 +205,6 @@
     name = "xplane_utils_test",
     srcs = ["xplane_utils_test.cc"],
     deps = [
-        ":time_utils",
         ":xplane_builder",
         ":xplane_utils",
         ":xplane_visitor",
@@ -205,6 +213,8 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -232,6 +242,7 @@
     deps = [
         ":xplane_schema",
         ":xplane_visitor",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
     ],
 )
 
@@ -243,9 +254,11 @@
     deps = [
         ":tf_op_utils",
         ":tf_xplane_visitor",
+        ":xplane_builder",
         ":xplane_schema",
         ":xplane_utils",
         ":xplane_visitor",
+        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -263,10 +276,13 @@
         ":xplane_builder",
         ":xplane_schema",
         ":xplane_utils",
+        ":xplane_visitor",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -281,10 +297,13 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler/costs:cost_estimator",
+        "//tensorflow/core/grappler/costs:op_context",
         "//tensorflow/core/grappler/costs:op_level_cost_estimator",
         "//tensorflow/core/grappler/costs:op_performance_data_cc",
-        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -296,6 +315,7 @@
         ":group_events",
         ":tf_op_utils",
         ":tf_xplane_visitor",
+        ":time_utils",
         ":timespan",
         ":trace_utils",
         ":xplane_builder",
@@ -305,8 +325,10 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -321,6 +343,8 @@
         ":xplane_builder",
         ":xplane_schema",
         ":xplane_utils",
+        ":xplane_visitor",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
@@ -347,10 +371,10 @@
         ":xplane_builder",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/framework:protos_all_cc",
         "//tensorflow/core/profiler/protobuf:tfstreamz_proto_cc",
         "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
     ],
 )
diff --git a/tensorflow/core/profiler/utils/cost_utils.cc b/tensorflow/core/profiler/utils/cost_utils.cc
index 754aa65..a94f09b 100644
--- a/tensorflow/core/profiler/utils/cost_utils.cc
+++ b/tensorflow/core/profiler/utils/cost_utils.cc
@@ -15,12 +15,27 @@
 
 #include "tensorflow/core/profiler/utils/cost_utils.h"
 
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_context.h"
 #include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/cost_utils.h b/tensorflow/core/profiler/utils/cost_utils.h
index f109555..a778bca 100644
--- a/tensorflow/core/profiler/utils/cost_utils.h
+++ b/tensorflow/core/profiler/utils/cost_utils.h
@@ -15,12 +15,13 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_
 
-#include <set>
+#include <string>
 
-#include "absl/strings/string_view.h"
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
 #include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
@@ -46,7 +47,8 @@
   OpRoofLineStats Predict(const XEventVisitor& event);
 
  private:
-  std::set<string> unsupported_ops_;  // summary for unsupported ops.
+  absl::flat_hash_set<std::string>
+      unsupported_ops_;  // summary for unsupported ops.
 
   TF_DISALLOW_COPY_AND_ASSIGN(TfOpRoofLineCostEstimator);
 };
diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc
index c99d8e8..112c097 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.cc
+++ b/tensorflow/core/profiler/utils/derived_timeline.cc
@@ -14,15 +14,27 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/utils/derived_timeline.h"
 
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/group_events.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/time_utils.h"
 #include "tensorflow/core/profiler/utils/timespan.h"
 #include "tensorflow/core/profiler/utils/trace_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h
index 61b62bd..cd4da79 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.h
+++ b/tensorflow/core/profiler/utils/derived_timeline.h
@@ -15,7 +15,13 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_
 
+#include <functional>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/group_events.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc
index f3e6b66..76a0188 100644
--- a/tensorflow/core/profiler/utils/derived_timeline_test.cc
+++ b/tensorflow/core/profiler/utils/derived_timeline_test.cc
@@ -15,8 +15,9 @@
 
 #include "tensorflow/core/profiler/utils/derived_timeline.h"
 
-#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/group_events.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
@@ -24,6 +25,7 @@
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/errors.cc b/tensorflow/core/profiler/utils/errors.cc
index d829ee0..9c678e9 100644
--- a/tensorflow/core/profiler/utils/errors.cc
+++ b/tensorflow/core/profiler/utils/errors.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/core/profiler/utils/errors.h"
 
+#include "absl/strings/string_view.h"
+
 namespace tensorflow {
 namespace profiler {
 
diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc
index 9768331..5e0413c 100644
--- a/tensorflow/core/profiler/utils/event_span.cc
+++ b/tensorflow/core/profiler/utils/event_span.cc
@@ -14,11 +14,19 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/utils/event_span.h"
 
+#include <string>
+#include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h
index 36b3172..1adc6a7 100644
--- a/tensorflow/core/profiler/utils/event_span.h
+++ b/tensorflow/core/profiler/utils/event_span.h
@@ -16,10 +16,11 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_
 
+#include <string>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
-#include "tensorflow/core/platform/logging.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/utils/timespan.h"
diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc
index 60d12c0..4296149 100644
--- a/tensorflow/core/profiler/utils/group_events.cc
+++ b/tensorflow/core/profiler/utils/group_events.cc
@@ -15,13 +15,25 @@
 
 #include "tensorflow/core/profiler/utils/group_events.h"
 
-#include <stack>
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
 
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
diff --git a/tensorflow/core/profiler/utils/group_events.h b/tensorflow/core/profiler/utils/group_events.h
index 1140f2d..4b6fc58 100644
--- a/tensorflow/core/profiler/utils/group_events.h
+++ b/tensorflow/core/profiler/utils/group_events.h
@@ -16,9 +16,16 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_GROUP_EVENTS_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_GROUP_EVENTS_H_
 
+#include <functional>
 #include <memory>
+#include <string>
+#include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
diff --git a/tensorflow/core/profiler/utils/group_events_test.cc b/tensorflow/core/profiler/utils/group_events_test.cc
index 6b6a0d2..11996ba 100644
--- a/tensorflow/core/profiler/utils/group_events_test.cc
+++ b/tensorflow/core/profiler/utils/group_events_test.cc
@@ -16,12 +16,15 @@
 #include "tensorflow/core/profiler/utils/group_events.h"
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc
index 75896c0..e2a4004 100644
--- a/tensorflow/core/profiler/utils/hardware_type_utils.cc
+++ b/tensorflow/core/profiler/utils/hardware_type_utils.cc
@@ -17,6 +17,7 @@
 
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/html_utils.h b/tensorflow/core/profiler/utils/html_utils.h
new file mode 100644
index 0000000..215d9f5
--- /dev/null
+++ b/tensorflow/core/profiler/utils/html_utils.h
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_
+#define TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_
+
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Creates an HTML anchor element that links to the given URL with the given text.
+inline std::string AnchorElement(absl::string_view url,
+                                 absl::string_view text) {
+  return absl::StrCat("<a href=\"", url, "\" target=\"_blank\">", text, "</a>");
+}
+
+}  // namespace profiler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_
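
For reference, the new AnchorElement helper simply instantiates a fixed anchor template around the given url and text. A minimal usage sketch (the URL and link text below are made-up examples, not values used anywhere in the profiler):

    #include <iostream>
    #include <string>

    #include "tensorflow/core/profiler/utils/html_utils.h"

    int main() {
      // Prints: <a href="https://example.com/profile" target="_blank">view profile</a>
      std::string link = tensorflow::profiler::AnchorElement(
          "https://example.com/profile", "view profile");
      std::cout << link << "\n";
      return 0;
    }
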
diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc
index 14038d5..c40c3a8 100644
--- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc
+++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc
@@ -15,15 +15,17 @@
 
 #include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
 
+#include <algorithm>
+#include <string>
 #include <tuple>
 #include <vector>
 
 #include "absl/strings/match.h"
 #include "absl/strings/numbers.h"
-#include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
 
 namespace tensorflow {
@@ -34,15 +36,15 @@
   const std::vector<absl::string_view> params =
       absl::StrSplit(xstat_kernel_details, absl::ByAnyChar(":\n"));
 
-  constexpr uint32_t kNumDimensions = 3;
-  for (uint32_t dim = 0; dim < kNumDimensions; ++dim) {
+  constexpr uint32 kNumDimensions = 3;
+  for (uint32 dim = 0; dim < kNumDimensions; ++dim) {
     kernel->add_block_dim(1);
     kernel->add_grid_dim(1);
   }
 
   // Process value pairs.
-  for (uint32_t ii = 0; ii < params.size(); ii += 2) {
-    uint32_t value = 0;
+  for (uint32 ii = 0; ii < params.size(); ii += 2) {
+    uint32 value = 0;
     if (params[ii] == "registers_per_thread" &&
         absl::SimpleAtoi(params[ii + 1], &value)) {
       kernel->set_registers_per_thread(value);
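
For context, the loop above walks a kernel-details string as alternating key/value tokens produced by splitting on ':' and '\n'. Below is a self-contained sketch of that decomposition with a hypothetical input string; it is not the profiler's real entry point, and it uses uint32_t directly since it does not pull in the TF platform typedefs:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    #include "absl/strings/numbers.h"
    #include "absl/strings/str_split.h"
    #include "absl/strings/string_view.h"

    int main() {
      // Hypothetical sample in the "key:value\n" layout the parser expects.
      absl::string_view details =
          "registers_per_thread:32\nstatic_shared_memory_usage:0";
      std::vector<absl::string_view> params =
          absl::StrSplit(details, absl::ByAnyChar(":\n"));
      // Even indices hold keys; odd indices hold the corresponding values.
      for (uint32_t i = 0; i + 1 < params.size(); i += 2) {
        uint32_t value = 0;
        if (absl::SimpleAtoi(params[i + 1], &value)) {
          std::cout << params[i] << " = " << value << "\n";
        }
      }
      return 0;
    }
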
diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc
index 06307d6..863d2f7 100644
--- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc
+++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc
@@ -15,8 +15,13 @@
 
 #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
 
+#include <algorithm>
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/utils/math_utils.h"
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
@@ -40,7 +45,7 @@
         /*hlo_module_id=*/0, tf_op_name);
     if (tf_op_metrics->category().empty()) {
       tf_op_metrics->set_category(
-          tf_op_type == kUnknownOp ? "Unknown" : string(tf_op_type));
+          tf_op_type == kUnknownOp ? "Unknown" : std::string(tf_op_type));
     }
     tf_op_metrics->set_is_eager(device_op_metrics.is_eager());
     // The occurrences of a TF-op is the maximum among the occurrences of all
@@ -89,8 +94,8 @@
 void AddIdleOp(OpMetricsDb* db) {
   uint64 idle_time_ps = IdleTimePs(*db);
   OpMetrics* metrics = db->add_metrics_db();
-  metrics->set_name(string(kIdle));
-  metrics->set_category(string(kIdle));
+  metrics->set_name(std::string(kIdle));
+  metrics->set_category(std::string(kIdle));
   metrics->set_occurrences(0);
   metrics->set_time_ps(idle_time_ps);
   metrics->set_self_time_ps(idle_time_ps);
diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc
index 74ce13d..921e061 100644
--- a/tensorflow/core/profiler/utils/op_utils.cc
+++ b/tensorflow/core/profiler/utils/op_utils.cc
@@ -15,8 +15,14 @@
 
 #include "tensorflow/core/profiler/utils/op_utils.h"
 
+#include <algorithm>
+#include <string>
+
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/utils/tf_op_utils.h"
 
 namespace tensorflow {
 namespace profiler {
@@ -69,9 +75,9 @@
   OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, name);
   if (op_metrics->category().empty())
     op_metrics->set_category(category == kUnknownOp ? "unknown"
-                                                    : string(category));
+                                                    : std::string(category));
   if (op_metrics->provenance().empty())
-    op_metrics->set_provenance(string(provenance));
+    op_metrics->set_provenance(std::string(provenance));
   op_metrics->set_is_eager(op_metrics->is_eager() || is_eager);
   op_metrics->set_occurrences(op_metrics->occurrences() + occurrences);
   op_metrics->set_time_ps(op_metrics->time_ps() + time_ps);
diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h
index 8aaa0f4..f94328d 100644
--- a/tensorflow/core/profiler/utils/op_utils.h
+++ b/tensorflow/core/profiler/utils/op_utils.h
@@ -16,13 +16,10 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_
 
-#include <string>
-
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
 #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
-#include "tensorflow/core/profiler/utils/tf_op_utils.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/tf_op_utils.cc b/tensorflow/core/profiler/utils/tf_op_utils.cc
index 5a42044..630a74c 100644
--- a/tensorflow/core/profiler/utils/tf_op_utils.cc
+++ b/tensorflow/core/profiler/utils/tf_op_utils.cc
@@ -15,11 +15,14 @@
 
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
 
+#include <string>
+#include <vector>
+
 #include "absl/strings/ascii.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
-#include "absl/strings/strip.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/regexp.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/tf_op_utils.h b/tensorflow/core/profiler/utils/tf_op_utils.h
index d1ac69e..b8af946 100644
--- a/tensorflow/core/profiler/utils/tf_op_utils.h
+++ b/tensorflow/core/profiler/utils/tf_op_utils.h
@@ -16,9 +16,9 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
 
+#include <string>
 #include <vector>
 
-#include "absl/base/attributes.h"
 #include "absl/strings/match.h"
 #include "absl/strings/string_view.h"
 
diff --git a/tensorflow/core/profiler/utils/tf_op_utils_test.cc b/tensorflow/core/profiler/utils/tf_op_utils_test.cc
index fa51695..136dbee 100644
--- a/tensorflow/core/profiler/utils/tf_op_utils_test.cc
+++ b/tensorflow/core/profiler/utils/tf_op_utils_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
 
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/tf_xplane_visitor.h b/tensorflow/core/profiler/utils/tf_xplane_visitor.h
index 33a170f..17a7b94 100644
--- a/tensorflow/core/profiler/utils/tf_xplane_visitor.h
+++ b/tensorflow/core/profiler/utils/tf_xplane_visitor.h
@@ -16,6 +16,7 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TF_XPLANE_VISITOR_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_TF_XPLANE_VISITOR_H_
 
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.cc b/tensorflow/core/profiler/utils/tfstreamz_utils.cc
index 493420f..f4cbaa8 100644
--- a/tensorflow/core/profiler/utils/tfstreamz_utils.cc
+++ b/tensorflow/core/profiler/utils/tfstreamz_utils.cc
@@ -14,36 +14,46 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/utils/tfstreamz_utils.h"
 
+#include <map>
 #include <memory>
+#include <string>
+#include <utility>
+#include <vector>
 
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/substitute.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/framework/summary.pb.h"
 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
-#include "tensorflow/core/lib/monitoring/collection_registry.h"
-#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/lib/monitoring/metric_def.h"
+#include "tensorflow/core/lib/monitoring/types.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/tfstreamz.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/xplane_builder.h"
 
 namespace tensorflow {
 namespace profiler {
 
 namespace {
-string ConstructXStatName(const string& name, const monitoring::Point& point) {
+
+std::string ConstructXStatName(absl::string_view name,
+                               const monitoring::Point& point) {
   if (point.labels.empty()) {
-    return name;
+    return std::string(name);
   }
   return absl::Substitute(
       "$0{$1}", name,
-      absl::StrJoin(point.labels, ", ",
-                    [](string* out, const monitoring::Point::Label& label) {
-                      absl::StrAppend(out, label.name, "=", label.value);
-                    }));
+      absl::StrJoin(
+          point.labels, ", ",
+          [](std::string* out, const monitoring::Point::Label& label) {
+            absl::StrAppend(out, label.name, "=", label.value);
+          }));
 }
 
-string SerializePercentile(const monitoring::Percentiles& percentiles) {
+std::string SerializePercentile(const monitoring::Percentiles& percentiles) {
   tfstreamz::Percentiles output;
   output.set_unit_of_measure(
       static_cast<tfstreamz::UnitOfMeasure>(percentiles.unit_of_measure));
@@ -81,11 +91,11 @@
     xevent.SetEndTimestampNs(snapshot.end_time_ns);
     auto& metric_descriptor_map = snapshot.metrics->metric_descriptor_map;
     for (const auto& point_set : snapshot.metrics->point_set_map) {
-      const string& metric_name = point_set.first;
+      const std::string& metric_name = point_set.first;
       // Each metric has multiple points corresponding to different labels.
       for (const auto& point : point_set.second->points) {
         // Generates one KPI metric for each point.
-        string stat_name = ConstructXStatName(metric_name, *point);
+        std::string stat_name = ConstructXStatName(metric_name, *point);
         auto* metadata = xplane.GetOrCreateStatMetadata(stat_name);
         auto it = metric_descriptor_map.find(metric_name);
         if (it != metric_descriptor_map.end()) {
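
For context, ConstructXStatName renders a metric name plus its labels in a "name{k1=v1, k2=v2}" shape. Here is a self-contained sketch of the same formatting, using a plain vector of label pairs in place of monitoring::Point (the metric name and label values below are invented):

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"
    #include "absl/strings/substitute.h"

    int main() {
      // Hypothetical labels standing in for monitoring::Point::Label entries.
      std::vector<std::pair<std::string, std::string>> labels = {
          {"device", "gpu:0"}, {"job", "worker"}};
      std::string stat_name = absl::Substitute(
          "$0{$1}", "/tensorflow/example/metric",
          absl::StrJoin(labels, ", ",
                        [](std::string* out,
                           const std::pair<std::string, std::string>& label) {
                          absl::StrAppend(out, label.first, "=", label.second);
                        }));
      // Prints: /tensorflow/example/metric{device=gpu:0, job=worker}
      std::cout << stat_name << "\n";
      return 0;
    }
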
diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.h b/tensorflow/core/profiler/utils/tfstreamz_utils.h
index ae8e407..1ab21ed 100644
--- a/tensorflow/core/profiler/utils/tfstreamz_utils.h
+++ b/tensorflow/core/profiler/utils/tfstreamz_utils.h
@@ -15,11 +15,13 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_
 
+#include <memory>
+#include <vector>
+
 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
-#include "tensorflow/core/lib/monitoring/collection_registry.h"
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
-#include "tensorflow/core/profiler/utils/xplane_builder.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/timespan.h b/tensorflow/core/profiler/utils/timespan.h
index bccbeaa..82775af 100644
--- a/tensorflow/core/profiler/utils/timespan.h
+++ b/tensorflow/core/profiler/utils/timespan.h
@@ -16,6 +16,9 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TIMESPAN_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_TIMESPAN_H_
 
+#include <algorithm>
+#include <string>
+
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/profiler/utils/xplane_builder.cc b/tensorflow/core/profiler/utils/xplane_builder.cc
index 9e66a15..f923f39 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.cc
+++ b/tensorflow/core/profiler/utils/xplane_builder.cc
@@ -14,6 +14,14 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 
+#include <algorithm>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/time_utils.h"
 
 namespace tensorflow {
@@ -54,7 +62,7 @@
   return metadata;
 }
 
-XEventMetadata* XPlaneBuilder::GetOrCreateEventMetadata(string&& name) {
+XEventMetadata* XPlaneBuilder::GetOrCreateEventMetadata(std::string&& name) {
   XEventMetadata*& metadata = event_metadata_by_name_[name];
   if (metadata == nullptr) {
     metadata =
diff --git a/tensorflow/core/profiler/utils/xplane_builder.h b/tensorflow/core/profiler/utils/xplane_builder.h
index 803cc7b..b0d743a 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.h
+++ b/tensorflow/core/profiler/utils/xplane_builder.h
@@ -15,10 +15,15 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_BUILDER_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_BUILDER_H_
 
+#include <stddef.h>
+
+#include <string>
+#include <utility>
+
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/time_utils.h"
@@ -53,12 +58,12 @@
   void AddStatValue(const XStatMetadata& metadata, absl::string_view value,
                     bool is_bytes = false) {
     if (is_bytes) {
-      AddStat(metadata)->set_bytes_value(string(value));
+      AddStat(metadata)->set_bytes_value(std::string(value));
     } else {
-      AddStat(metadata)->set_str_value(string(value));
+      AddStat(metadata)->set_str_value(std::string(value));
     }
   }
-  void AddStatValue(const XStatMetadata& metadata, string&& value,
+  void AddStatValue(const XStatMetadata& metadata, std::string&& value,
                     bool is_bytes = false) {
     if (is_bytes) {
       AddStat(metadata)->set_bytes_value(std::move(value));
@@ -160,7 +165,7 @@
 
   int64 NumEvents() { return line_->events_size(); }
 
-  void SetName(absl::string_view name) { line_->set_name(string(name)); }
+  void SetName(absl::string_view name) { line_->set_name(std::string(name)); }
 
   void SetNameIfEmpty(absl::string_view name) {
     if (line_->name().empty()) SetName(name);
@@ -205,7 +210,7 @@
   int64 Id() { return plane_->id(); }
   void SetId(int64 id) { plane_->set_id(id); }
 
-  void SetName(absl::string_view name) { plane_->set_name(string(name)); }
+  void SetName(absl::string_view name) { plane_->set_name(std::string(name)); }
 
   void ReserveLines(size_t num_lines) {
     plane_->mutable_lines()->Reserve(num_lines);
@@ -222,7 +227,7 @@
 
   XEventMetadata* GetOrCreateEventMetadata(int64 metadata_id);
   XEventMetadata* GetOrCreateEventMetadata(absl::string_view name);
-  XEventMetadata* GetOrCreateEventMetadata(string&& name);
+  XEventMetadata* GetOrCreateEventMetadata(std::string&& name);
   inline XEventMetadata* GetOrCreateEventMetadata(const char* name) {
     return GetOrCreateEventMetadata(absl::string_view(name));
   }
@@ -251,7 +256,7 @@
   if (stat.value_case() == XStat::kRefValue) {
     const auto& stat_metadata_map = src.stat_metadata();
     const auto it = stat_metadata_map.find(stat.ref_value());
-    if (ABSL_PREDICT_FALSE(it == stat_metadata_map.end())) {
+    if (TF_PREDICT_FALSE(it == stat_metadata_map.end())) {
       // the reference value in stat is not found in XStatMetadata from src.
       return;
     }
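
The std::string&& overloads introduced above exist so a caller that already owns a string can transfer its buffer instead of paying for a copy through the string_view path. A minimal sketch of the same two-overload pattern, with hypothetical names:

    #include <iostream>
    #include <string>
    #include <utility>

    #include "absl/strings/string_view.h"

    struct Stat {
      std::string str_value;
    };

    // Copying overload: materializes a fresh std::string from the view.
    void AddStatValue(Stat* stat, absl::string_view value) {
      stat->str_value = std::string(value);
    }

    // Moving overload: steals the caller's buffer; no copy is made.
    void AddStatValue(Stat* stat, std::string&& value) {
      stat->str_value = std::move(value);
    }

    int main() {
      Stat stat;
      std::string owned = "an already-built value";
      AddStatValue(&stat, std::move(owned));  // binds to the && overload
      std::cout << stat.str_value << "\n";
      return 0;
    }
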
diff --git a/tensorflow/core/profiler/utils/xplane_builder_test.cc b/tensorflow/core/profiler/utils/xplane_builder_test.cc
index cb87497..e55e01d 100644
--- a/tensorflow/core/profiler/utils/xplane_builder_test.cc
+++ b/tensorflow/core/profiler/utils/xplane_builder_test.cc
@@ -14,7 +14,11 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 
+#include <string>
+
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc
index 51bc4d0..f8ff31b 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.cc
+++ b/tensorflow/core/profiler/utils/xplane_schema.cc
@@ -17,7 +17,10 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h
index 97e54a7..31ff901 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.h
+++ b/tensorflow/core/profiler/utils/xplane_schema.h
@@ -16,11 +16,10 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_
 
-#include "absl/strings/match.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
-#include "absl/types/span.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc
index b2cc1fd..7f5221c 100644
--- a/tensorflow/core/profiler/utils/xplane_utils.cc
+++ b/tensorflow/core/profiler/utils/xplane_utils.cc
@@ -14,12 +14,21 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
 
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/timespan.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
+#include "tensorflow/core/profiler/utils/xplane_schema.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h
index 4f0a8b8..49087c4 100644
--- a/tensorflow/core/profiler/utils/xplane_utils.h
+++ b/tensorflow/core/profiler/utils/xplane_utils.h
@@ -17,6 +17,7 @@
 
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/utils/xplane_utils_test.cc b/tensorflow/core/profiler/utils/xplane_utils_test.cc
index b9b15b2..04e06fc 100644
--- a/tensorflow/core/profiler/utils/xplane_utils_test.cc
+++ b/tensorflow/core/profiler/utils/xplane_utils_test.cc
@@ -15,9 +15,14 @@
 
 #include "tensorflow/core/profiler/utils/xplane_utils.h"
 
+#include <string>
+
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/platform/env_time.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/profiler/utils/xplane_builder.h"
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
diff --git a/tensorflow/core/profiler/utils/xplane_visitor.cc b/tensorflow/core/profiler/utils/xplane_visitor.cc
index ab97271..42068b7 100644
--- a/tensorflow/core/profiler/utils/xplane_visitor.cc
+++ b/tensorflow/core/profiler/utils/xplane_visitor.cc
@@ -14,7 +14,16 @@
 ==============================================================================*/
 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
 
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 
 namespace tensorflow {
 namespace profiler {
diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h
index 8cd805c..4120a28 100644
--- a/tensorflow/core/profiler/utils/xplane_visitor.h
+++ b/tensorflow/core/profiler/utils/xplane_visitor.h
@@ -15,8 +15,11 @@
 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_
 #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_
 
+#include <stddef.h>
+
 #include <functional>
-#include <utility>
+#include <string>
+#include <vector>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD
new file mode 100644
index 0000000..a374c80
--- /dev/null
+++ b/tensorflow/core/protobuf/BUILD
@@ -0,0 +1,182 @@
+# For platform-specific build config
+load(
+    "//tensorflow/core/platform:build_config.bzl",
+    "tf_additional_all_protos",
+    "tf_proto_library",
+    "tf_proto_library_cc",
+    "tf_pyclif_proto_library",
+)
+
+package(
+    default_visibility = [
+        "//tensorflow:internal",
+        "//tensorflow/core:__subpackages__",
+        "//tensorflow_models:__subpackages__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+COMMON_PROTO_SRCS = [
+    "bfc_memory_map.proto",
+    "config.proto",
+    "cluster.proto",
+    "debug.proto",
+    "device_filters.proto",
+    "device_properties.proto",
+    "graph_debug_info.proto",
+    "queue_runner.proto",
+    "rewriter_config.proto",
+    "tensor_bundle.proto",
+    "saver.proto",
+    "verifier_config.proto",
+]
+
+[
+    [
+        tf_pyclif_proto_library(
+            name = "%s_pyclif" % proto_name,
+            proto_lib = ":for_core_protos",
+            proto_srcfile = "%s.proto" % proto_name,
+            visibility = ["//visibility:public"],
+        ),
+    ]
+    for proto_name in [
+        "config",
+        "device_properties",
+        "graph_debug_info",
+        "meta_graph",
+        "saved_model",
+    ]
+]
+
+tf_proto_library(
+    name = "autotuning_proto",
+    srcs = ["autotuning.proto"],
+    cc_api_version = 2,
+    make_default_target_header_only = True,
+)
+
+tf_proto_library(
+    name = "conv_autotuning_proto",
+    srcs = ["conv_autotuning.proto"],
+    cc_api_version = 2,
+    make_default_target_header_only = True,
+    protodeps = [
+        "//tensorflow/stream_executor:dnn_proto",
+    ],
+)
+
+tf_proto_library_cc(
+    name = "worker_proto",
+    srcs = ["worker.proto"],
+    cc_api_version = 2,
+    protodeps = tf_additional_all_protos(),
+    visibility = ["//visibility:public"],
+)
+
+tf_proto_library_cc(
+    name = "worker_service_proto",
+    srcs = ["worker_service.proto"],
+    has_services = 1,
+    cc_api_version = 2,
+    cc_stubby_versions = ["2"],
+    protodeps = [":worker_proto"],
+)
+
+tf_proto_library_cc(
+    name = "master_proto",
+    srcs = ["master.proto"],
+    cc_api_version = 2,
+    protodeps = tf_additional_all_protos(),
+    visibility = ["//tensorflow:internal"],
+)
+
+tf_proto_library_cc(
+    name = "master_service_proto",
+    srcs = ["master_service.proto"],
+    has_services = 1,
+    cc_api_version = 2,
+    cc_stubby_versions = ["2"],
+    protodeps = [":master_proto"],
+)
+
+tf_proto_library_cc(
+    name = "eager_service_proto",
+    srcs = ["eager_service.proto"],
+    has_services = 1,
+    cc_api_version = 2,
+    cc_grpc_version = 1,
+    cc_stubby_versions = ["2"],
+    protodeps = tf_additional_all_protos(),
+)
+
+tf_proto_library_cc(
+    name = "replay_log_proto",
+    srcs = ["replay_log.proto"],
+    cc_api_version = 2,
+    protodeps = [
+        ":master_proto",
+    ] + tf_additional_all_protos(),
+)
+
+tf_proto_library(
+    name = "error_codes_proto_impl",
+    srcs = ["error_codes.proto"],
+    cc_api_version = 2,
+    make_default_target_header_only = True,
+)
+
+exports_files(
+    srcs = ["error_codes.proto"] + COMMON_PROTO_SRCS + [
+        # Protos which are not needed on mobile builds, but should be included
+        # in protos_all.
+        #
+        # Note that some protos are in neither core_proto_srcs nor this
+        # filegroup; e.g. ones with individual proto_library targets.
+        "control_flow.proto",
+        # TODO(ebrevdo): Re-enable once CriticalSection is in core.
+        # "critical_section.proto",
+        "data/experimental/snapshot.proto",
+        "debug_event.proto",
+        "meta_graph.proto",
+        "named_tensor.proto",
+        "remote_tensor_handle.proto",
+        "saved_model.proto",
+        "saved_object_graph.proto",
+        "struct.proto",
+        "tensorflow_server.proto",
+        "trackable_object_graph.proto",
+        "transport_options.proto",
+    ],
+)
+
+tf_proto_library(
+    name = "for_core_protos",
+    srcs = COMMON_PROTO_SRCS + [
+        # Protos which are not needed on mobile builds, but should be included
+        # in protos_all.
+        #
+        # Note that some protos are in neither core_proto_srcs nor this
+        # filegroup; e.g. ones with individual proto_library targets.
+        "control_flow.proto",
+        # TODO(ebrevdo): Re-enable once CriticalSection is in core.
+        # "critical_section.proto",
+        "data/experimental/snapshot.proto",
+        "debug_event.proto",
+        "meta_graph.proto",
+        "named_tensor.proto",
+        "remote_tensor_handle.proto",
+        "saved_model.proto",
+        "saved_object_graph.proto",
+        "struct.proto",
+        "tensorflow_server.proto",
+        "trackable_object_graph.proto",
+        "transport_options.proto",
+    ],
+    cc_api_version = 2,
+    make_default_target_header_only = True,
+    protodeps = [
+        ":error_codes_proto_impl",
+        "//tensorflow/core/framework:protos_all",
+    ],
+)
diff --git a/tensorflow/core/protobuf/remote_tensor_handle.proto b/tensorflow/core/protobuf/remote_tensor_handle.proto
index 1099522..36e3f81 100644
--- a/tensorflow/core/protobuf/remote_tensor_handle.proto
+++ b/tensorflow/core/protobuf/remote_tensor_handle.proto
@@ -21,11 +21,11 @@
   int64 op_id = 1;
   // The index into the outputs of the operation that produced this tensor.
   int32 output_num = 2;
-  // Device of the operation that produced this tensor. Cannot be empty.
+  // Device where the tensor is located. Cannot be empty.
   // For multi-device functions, it's the default device passed to placer.
   string device = 3;
-  // Device where the tensor is located. Can be empty if the operation producing
-  // this tensor is a multi-device function.
+  // Device of the operation producing this tensor. Can be empty if the
+  // operation producing this tensor is a multi-device function.
   string op_device = 4;
   // Tensor type.
   DataType dtype = 5;
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index b7f2b76..23e6138 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -21,7 +21,7 @@
 // Also update tensorflow/tensorflow.bzl and
 // tensorflow/tools/pip_package/setup.py
 #define TF_MAJOR_VERSION 2
-#define TF_MINOR_VERSION 1
+#define TF_MINOR_VERSION 2
 #define TF_PATCH_VERSION 0
 
 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
@@ -108,7 +108,7 @@
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 397  // Updated: 2020/5/10
+#define TF_GRAPH_DEF_VERSION 399  // Updated: 2020/5/12
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 43b2d93..4ea5fc3 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -51,3 +51,11 @@
         "//tensorflow/core:lib",
     ],
 )
+
+cc_library(
+    name = "tpu_config_c_api",
+    hdrs = ["tpu_config_c_api.h"],
+    deps = [
+        "//tensorflow/c:tf_status",
+    ],
+)
diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h
new file mode 100644
index 0000000..334a6a1
--- /dev/null
+++ b/tensorflow/core/tpu/tpu_config_c_api.h
@@ -0,0 +1,54 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
+#define TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
+
+#include <cstddef>
+
+#include "tensorflow/c/tf_status.h"
+
+typedef struct TpuSerializedProto TpuSerializedProto;
+
+extern "C" {
+
+bool TPUHostInitialized();
+
+// TODO(frankchn): Modify API to take in raw values instead of Tensors.
+void ConfigureDistributedTpuOp_DoWork(size_t input_size,
+                                      TpuSerializedProto** inputs,
+                                      TpuSerializedProto* output,
+                                      TF_Status* status);
+
+void WaitForDistributedTpuOp_DoWork(size_t input_size,
+                                    TpuSerializedProto** inputs,
+                                    TpuSerializedProto* output,
+                                    TF_Status* status);
+
+void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
+
+void InitializeHostForDistributedTpuOp_DoWork(
+    size_t input_size, TpuSerializedProto** inputs,
+    bool enable_whole_mesh_compilations, TpuSerializedProto* output,
+    TF_Status* status);
+
+void SetGlobalTPUArrayOp_DoWork(size_t input_size, TpuSerializedProto** inputs,
+                                TF_Status* status);
+
+void DisconnectDistributedTpuChipsOp_DoWork(TpuSerializedProto* output,
+                                            TF_Status* status);
+}
+
+#endif  // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
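
Since every entry point declared above reports failure through its TF_Status out-parameter, a caller allocates a status, invokes the op, and inspects the result code. A hypothetical sketch, assuming an implementation of these declared-but-not-defined functions is linked in:

    #include <cstdio>

    #include "tensorflow/c/tf_status.h"
    #include "tensorflow/core/tpu/tpu_config_c_api.h"

    int main() {
      TF_Status* status = TF_NewStatus();
      // Requests shutdown of the distributed TPU system; errors land in `status`.
      ShutdownDistributedTpuOp_DoWork(status);
      if (TF_GetCode(status) != TF_OK) {
        std::fprintf(stderr, "TPU shutdown failed: %s\n", TF_Message(status));
      }
      TF_DeleteStatus(status);
      return 0;
    }
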
diff --git a/tensorflow/core/util/debug_events_writer.cc b/tensorflow/core/util/debug_events_writer.cc
index 595f92d..d9c3393 100644
--- a/tensorflow/core/util/debug_events_writer.cc
+++ b/tensorflow/core/util/debug_events_writer.cc
@@ -179,7 +179,7 @@
   metadata->set_tensorflow_version(TF_VERSION_STRING);
   metadata->set_file_version(
       strings::Printf("%s%d", kVersionPrefix, kCurrentFormatVersion));
-  SerializeAndWriteDebugEvent(&debug_event, METADATA);
+  TF_RETURN_IF_ERROR(SerializeAndWriteDebugEvent(&debug_event, METADATA));
   TF_RETURN_WITH_CONTEXT_IF_ERROR(
       metadata_writer_->Flush(), "Failed to flush debug event metadata writer");
 
@@ -189,38 +189,38 @@
   return Status::OK();
 }
 
-void DebugEventsWriter::WriteSourceFile(SourceFile* source_file) {
+Status DebugEventsWriter::WriteSourceFile(SourceFile* source_file) {
   DebugEvent debug_event;
   debug_event.set_allocated_source_file(source_file);
-  SerializeAndWriteDebugEvent(&debug_event, SOURCE_FILES);
+  return SerializeAndWriteDebugEvent(&debug_event, SOURCE_FILES);
 }
 
-void DebugEventsWriter::WriteStackFrameWithId(
+Status DebugEventsWriter::WriteStackFrameWithId(
     StackFrameWithId* stack_frame_with_id) {
   DebugEvent debug_event;
   debug_event.set_allocated_stack_frame_with_id(stack_frame_with_id);
-  SerializeAndWriteDebugEvent(&debug_event, STACK_FRAMES);
+  return SerializeAndWriteDebugEvent(&debug_event, STACK_FRAMES);
 }
 
-void DebugEventsWriter::WriteGraphOpCreation(
+Status DebugEventsWriter::WriteGraphOpCreation(
     GraphOpCreation* graph_op_creation) {
   DebugEvent debug_event;
   debug_event.set_allocated_graph_op_creation(graph_op_creation);
-  SerializeAndWriteDebugEvent(&debug_event, GRAPHS);
+  return SerializeAndWriteDebugEvent(&debug_event, GRAPHS);
 }
 
-void DebugEventsWriter::WriteDebuggedGraph(DebuggedGraph* debugged_graph) {
+Status DebugEventsWriter::WriteDebuggedGraph(DebuggedGraph* debugged_graph) {
   DebugEvent debug_event;
   debug_event.set_allocated_debugged_graph(debugged_graph);
-  SerializeAndWriteDebugEvent(&debug_event, GRAPHS);
+  return SerializeAndWriteDebugEvent(&debug_event, GRAPHS);
 }
 
-void DebugEventsWriter::WriteExecution(Execution* execution) {
+Status DebugEventsWriter::WriteExecution(Execution* execution) {
   if (circular_buffer_size_ <= 0) {
     // No cyclic-buffer behavior.
     DebugEvent debug_event;
     debug_event.set_allocated_execution(execution);
-    SerializeAndWriteDebugEvent(&debug_event, EXECUTION);
+    return SerializeAndWriteDebugEvent(&debug_event, EXECUTION);
   } else {
     // Circular buffer behavior.
     DebugEvent debug_event;
@@ -234,16 +234,18 @@
     if (execution_buffer_.size() > circular_buffer_size_) {
       execution_buffer_.pop_front();
     }
+    return Status::OK();
   }
 }
 
-void DebugEventsWriter::WriteGraphExecutionTrace(
+Status DebugEventsWriter::WriteGraphExecutionTrace(
     GraphExecutionTrace* graph_execution_trace) {
+  TF_RETURN_IF_ERROR(Init());
   if (circular_buffer_size_ <= 0) {
     // No cyclic-buffer behavior.
     DebugEvent debug_event;
     debug_event.set_allocated_graph_execution_trace(graph_execution_trace);
-    SerializeAndWriteDebugEvent(&debug_event, GRAPH_EXECUTION_TRACES);
+    return SerializeAndWriteDebugEvent(&debug_event, GRAPH_EXECUTION_TRACES);
   } else {
     // Circular buffer behavior.
     DebugEvent debug_event;
@@ -257,15 +259,14 @@
     if (graph_execution_trace_buffer_.size() > circular_buffer_size_) {
       graph_execution_trace_buffer_.pop_front();
     }
+    return Status::OK();
   }
 }
 
-void DebugEventsWriter::WriteGraphExecutionTrace(const string& tfdbg_context_id,
-                                                 const string& device_name,
-                                                 const string& op_name,
-                                                 int32 output_slot,
-                                                 int32 tensor_debug_mode,
-                                                 const Tensor& tensor_value) {
+Status DebugEventsWriter::WriteGraphExecutionTrace(
+    const string& tfdbg_context_id, const string& device_name,
+    const string& op_name, int32 output_slot, int32 tensor_debug_mode,
+    const Tensor& tensor_value) {
   std::unique_ptr<GraphExecutionTrace> trace(new GraphExecutionTrace());
   trace->set_tfdbg_context_id(tfdbg_context_id);
   if (!op_name.empty()) {
@@ -279,7 +280,7 @@
   }
   trace->set_device_name(device_name);
   tensor_value.AsProtoTensorContent(trace->mutable_tensor_proto());
-  WriteGraphExecutionTrace(trace.release());
+  return WriteGraphExecutionTrace(trace.release());
 }
 
 void DebugEventsWriter::WriteSerializedNonExecutionDebugEvent(
@@ -487,8 +488,8 @@
   return Status::OK();
 }
 
-void DebugEventsWriter::SerializeAndWriteDebugEvent(DebugEvent* debug_event,
-                                                    DebugEventFileType type) {
+Status DebugEventsWriter::SerializeAndWriteDebugEvent(DebugEvent* debug_event,
+                                                      DebugEventFileType type) {
   std::unique_ptr<SingleDebugEventFileWriter>* writer = nullptr;
   SelectWriter(type, &writer);
   if (writer != nullptr) {
@@ -497,6 +498,11 @@
     string str;
     debug_event->AppendToString(&str);
     (*writer)->WriteSerializedDebugEvent(str);
+    return Status::OK();
+  } else {
+    return errors::Internal(
+        "Unable to find debug events file writer for DebugEventsFileType ",
+        type);
   }
 }
 
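
The change above is the standard TensorFlow void-to-Status conversion: each writer now surfaces its failure instead of dropping it, and callers chain the results with TF_RETURN_IF_ERROR. A schematic sketch of the pattern, with hypothetical names:

    #include <iostream>

    #include "tensorflow/core/platform/errors.h"
    #include "tensorflow/core/platform/status.h"

    namespace {

    // Before the refactor this would have been `void WriteRecord(...)`,
    // silently swallowing failures.
    tensorflow::Status WriteRecord(bool writer_available) {
      if (!writer_available) {
        return tensorflow::errors::Internal("no writer available");
      }
      return tensorflow::Status::OK();
    }

    tensorflow::Status WriteAll() {
      // TF_RETURN_IF_ERROR early-returns any non-OK Status to the caller.
      TF_RETURN_IF_ERROR(WriteRecord(/*writer_available=*/true));
      return tensorflow::Status::OK();
    }

    }  // namespace

    int main() {
      std::cout << WriteAll().ToString() << "\n";
      return 0;
    }
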
diff --git a/tensorflow/core/util/debug_events_writer.h b/tensorflow/core/util/debug_events_writer.h
index 6d219d7..39835ad 100644
--- a/tensorflow/core/util/debug_events_writer.h
+++ b/tensorflow/core/util/debug_events_writer.h
@@ -119,27 +119,27 @@
   // The four DebugEvent fields below are written _without_ the circular buffer.
   // Source file contents are written to the *.source_files file.
   // Takes ownership of source_file.
-  void WriteSourceFile(SourceFile* source_file);
+  Status WriteSourceFile(SourceFile* source_file);
   // Stack frames are written to the *.code_locations file.
   // Takes ownership of stack_frame_with_id.
-  void WriteStackFrameWithId(StackFrameWithId* stack_frame_with_id);
+  Status WriteStackFrameWithId(StackFrameWithId* stack_frame_with_id);
   // Graph op creation events are written to the *.graphs file.
   // Takes ownership of graph_op_creation.
-  void WriteGraphOpCreation(GraphOpCreation* graph_op_creation);
+  Status WriteGraphOpCreation(GraphOpCreation* graph_op_creation);
   // Debugged graphs are written to the *.graphs file.
   // Takes ownership of debugged_graph.
-  void WriteDebuggedGraph(DebuggedGraph* debugged_graph);
+  Status WriteDebuggedGraph(DebuggedGraph* debugged_graph);
 
   // The two DebugEvent fields below are written to the circular buffer
   // and saved to disk only at the FlushExecutionFiles() call.
   // Execution events (eager execution of an op or a tf.function) are written to
   // the *.execution file.
   // Takes ownership of execution.
-  void WriteExecution(Execution* execution);
+  Status WriteExecution(Execution* execution);
   // Graph execution traces (graph-internal tensor values or their summaries)
   // are written to the *.graph_execution_traces file.
   // Takes ownership of graph_execution_trace.
-  void WriteGraphExecutionTrace(GraphExecutionTrace* graph_execution_trace);
+  Status WriteGraphExecutionTrace(GraphExecutionTrace* graph_execution_trace);
 
   // Write a graph execution trace without using a protocol buffer.
   // Instead, pass the raw values related to the graph execution trace.
@@ -155,11 +155,11 @@
   //   tensor_value: The value of the tensor that describes the tensor(s)
   //     that this trace is concerned with. The semantics of this tensor value
   //     depends on the value of `tensor_debug_mode`.
-  void WriteGraphExecutionTrace(const string& tfdbg_context_id,
-                                const string& device_name,
-                                const string& op_name, int32 output_slot,
-                                int32 tensor_debug_mode,
-                                const Tensor& tensor_value);
+  Status WriteGraphExecutionTrace(const string& tfdbg_context_id,
+                                  const string& device_name,
+                                  const string& op_name, int32 output_slot,
+                                  int32 tensor_debug_mode,
+                                  const Tensor& tensor_value);
 
   // Writes a serialized DebugEvent to one of the debug-events files
   // concerned with the non-execution events: the SOURCE_FILES, STACK_FRAMES
@@ -217,8 +217,8 @@
   // Initialize the TFRecord writer for non-metadata file type.
   Status InitNonMetadataFile(DebugEventFileType type);
 
-  void SerializeAndWriteDebugEvent(DebugEvent* debug_event,
-                                   DebugEventFileType type);
+  Status SerializeAndWriteDebugEvent(DebugEvent* debug_event,
+                                     DebugEventFileType type);
 
   void SelectWriter(DebugEventFileType type,
                     std::unique_ptr<SingleDebugEventFileWriter>** writer);
diff --git a/tensorflow/core/util/debug_events_writer_test.cc b/tensorflow/core/util/debug_events_writer_test.cc
index 66cde55..bd0c731 100644
--- a/tensorflow/core/util/debug_events_writer_test.cc
+++ b/tensorflow/core/util/debug_events_writer_test.cc
@@ -263,7 +263,7 @@
   source_file_1->add_lines("");
   source_file_1->add_lines("print(tf.constant([42.0]))");
   source_file_1->add_lines("");
-  writer->WriteSourceFile(source_file_1);
+  TF_ASSERT_OK(writer->WriteSourceFile(source_file_1));
 
   SourceFile* source_file_2 = new SourceFile();
   source_file_2->set_file_path("/home/tf_programs/train.py");
@@ -271,7 +271,7 @@
   source_file_2->add_lines("import tensorflow.keras as keras");
   source_file_2->add_lines("");
   source_file_2->add_lines("model = keras.Sequential()");
-  writer->WriteSourceFile(source_file_2);
+  TF_ASSERT_OK(writer->WriteSourceFile(source_file_2));
 
   TF_ASSERT_OK(writer->FlushNonExecutionFiles());
   TF_ASSERT_OK(writer->Close());
@@ -336,8 +336,8 @@
   file_line_col->set_func("my_func");
   file_line_col->set_code("  x = x ** 2.0");
 
-  writer->WriteStackFrameWithId(stack_frame_1);
-  writer->WriteStackFrameWithId(stack_frame_2);
+  TF_ASSERT_OK(writer->WriteStackFrameWithId(stack_frame_1));
+  TF_ASSERT_OK(writer->WriteStackFrameWithId(stack_frame_2));
   TF_ASSERT_OK(writer->FlushNonExecutionFiles());
   TF_ASSERT_OK(writer->Close());
 
@@ -382,12 +382,12 @@
   GraphOpCreation* graph_op_creation = new GraphOpCreation();
   graph_op_creation->set_op_type("MatMul");
   graph_op_creation->set_op_name("Dense_1/MatMul");
-  writer->WriteGraphOpCreation(graph_op_creation);
+  TF_ASSERT_OK(writer->WriteGraphOpCreation(graph_op_creation));
 
   DebuggedGraph* debugged_graph = new DebuggedGraph();
   debugged_graph->set_graph_id("deadbeaf");
   debugged_graph->set_graph_name("my_func_graph");
-  writer->WriteDebuggedGraph(debugged_graph);
+  TF_ASSERT_OK(writer->WriteDebuggedGraph(debugged_graph));
 
   TF_ASSERT_OK(writer->FlushNonExecutionFiles());
   TF_ASSERT_OK(writer->Close());
@@ -428,7 +428,7 @@
     SourceFile* source_file = new SourceFile();
     source_file->set_file_path(file_path);
     source_file->set_host_name("localhost.localdomain");
-    writer->WriteSourceFile(source_file);
+    TF_ASSERT_OK(writer->WriteSourceFile(source_file));
   };
   for (size_t i = 0; i < kConcurrentWrites; ++i) {
     thread_pool->Schedule(fn);
@@ -469,7 +469,7 @@
     SourceFile* source_file = new SourceFile();
     source_file->set_file_path(file_path);
     source_file->set_host_name("localhost.localdomain");
-    writer->WriteSourceFile(source_file);
+    TF_ASSERT_OK(writer->WriteSourceFile(source_file));
     TF_ASSERT_OK(writer->FlushNonExecutionFiles());
   };
   for (size_t i = 0; i < kConcurrentWrites; ++i) {
@@ -512,16 +512,16 @@
       source_file->set_file_path(
           strings::Printf("/home/tf_programs/program_%.2d.py", index));
       source_file->set_host_name("localhost.localdomain");
-      writer->WriteSourceFile(source_file);
+      TF_ASSERT_OK(writer->WriteSourceFile(source_file));
     } else if (index % 3 == 1) {
       StackFrameWithId* stack_frame = new StackFrameWithId();
       stack_frame->set_id(strings::Printf("e%.2d", index));
-      writer->WriteStackFrameWithId(stack_frame);
+      TF_ASSERT_OK(writer->WriteStackFrameWithId(stack_frame));
     } else {
       GraphOpCreation* op_creation = new GraphOpCreation();
       op_creation->set_op_type("Log");
       op_creation->set_op_name(strings::Printf("Log_%.2d", index));
-      writer->WriteGraphOpCreation(op_creation);
+      TF_ASSERT_OK(writer->WriteGraphOpCreation(op_creation));
     }
   };
   for (size_t i = 0; i < kConcurrentWrites; ++i) {
@@ -586,7 +586,7 @@
     Execution* execution = new Execution();
     execution->set_op_type("Log");
     execution->add_input_tensor_ids(i);
-    writer->WriteExecution(execution);
+    TF_ASSERT_OK(writer->WriteExecution(execution));
   }
 
   std::vector<DebugEvent> actuals;
@@ -611,7 +611,7 @@
     Execution* execution = new Execution();
     execution->set_op_type("Log");
     execution->add_input_tensor_ids(i);
-    writer->WriteExecution(execution);
+    TF_ASSERT_OK(writer->WriteExecution(execution));
   }
 
   TF_ASSERT_OK(writer->FlushExecutionFiles());
@@ -637,7 +637,7 @@
     Execution* execution = new Execution();
     execution->set_op_type("Abs");
     execution->add_input_tensor_ids(counter.fetch_add(1));
-    writer->WriteExecution(execution);
+    TF_ASSERT_OK(writer->WriteExecution(execution));
   };
   for (size_t i = 0; i < kCyclicBufferSize * 2; ++i) {
     thread_pool->Schedule(fn);
@@ -682,7 +682,7 @@
   for (size_t i = 0; i < kCyclicBufferSize * 2; ++i) {
     GraphExecutionTrace* trace = new GraphExecutionTrace();
     trace->set_tfdbg_context_id(strings::Printf("graph_%.2ld", i));
-    writer->WriteGraphExecutionTrace(trace);
+    TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
   }
 
   std::vector<DebugEvent> actuals;
@@ -695,6 +695,31 @@
   TF_ASSERT_OK(writer->Close());
 }
 
+TEST_F(DebugEventsWriterTest, WriteGraphExecutionTraceWithoutPreviousInitCall) {
+  const size_t kCyclicBufferSize = -1;
+  DebugEventsWriter* writer =
+      DebugEventsWriter::GetDebugEventsWriter(dump_root_, kCyclicBufferSize);
+  // NOTE(cais): `writer->Init()` is not called here before
+  // WriteGraphExecutionTrace() is called. This test checks that this is okay
+  // and the `GraphExecutionTrace` gets written correctly even without `Init()`
+  // being called first. This scenario can happen when a TF Graph with tfdbg
+  // debug ops is executed on a remote TF server.
+
+  GraphExecutionTrace* trace = new GraphExecutionTrace();
+  trace->set_tfdbg_context_id(strings::Printf("graph_0"));
+  TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
+  TF_ASSERT_OK(writer->FlushExecutionFiles());
+
+  std::vector<DebugEvent> actuals;
+  ReadDebugEventProtos(writer, DebugEventFileType::GRAPH_EXECUTION_TRACES,
+                       &actuals);
+  EXPECT_EQ(actuals.size(), 1);
+  EXPECT_EQ(actuals[0].graph_execution_trace().tfdbg_context_id(), "graph_0");
+
+  // Close the writer so the files can be safely deleted.
+  TF_ASSERT_OK(writer->Close());
+}
+
 TEST_F(DebugEventsWriterTest, WriteGrahExecutionTraceWithCyclicBufferFlush) {
   const size_t kCyclicBufferSize = 10;
   DebugEventsWriter* writer =
@@ -706,7 +731,7 @@
   for (size_t i = 0; i < kCyclicBufferSize * 2; ++i) {
     GraphExecutionTrace* trace = new GraphExecutionTrace();
     trace->set_tfdbg_context_id(strings::Printf("graph_%.2ld", i));
-    writer->WriteGraphExecutionTrace(trace);
+    TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
   }
 
   TF_ASSERT_OK(writer->FlushExecutionFiles());
@@ -731,7 +756,7 @@
     GraphExecutionTrace* trace = new GraphExecutionTrace();
     trace->set_tfdbg_context_id(
         strings::Printf("new_graph_%.2ld", counter.fetch_add(1)));
-    writer->WriteGraphExecutionTrace(trace);
+    TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
   };
   for (size_t i = 0; i < kCyclicBufferSize * 2; ++i) {
     thread_pool->Schedule(fn);
@@ -818,7 +843,7 @@
     Execution* execution = new Execution();
     execution->set_op_type("Log");
     execution->add_input_tensor_ids(i);
-    writer->WriteExecution(execution);
+    TF_ASSERT_OK(writer->WriteExecution(execution));
   }
   TF_ASSERT_OK(writer->FlushExecutionFiles());
 
@@ -834,7 +859,7 @@
   for (size_t i = 0; i < kNumEvents; ++i) {
     GraphExecutionTrace* trace = new GraphExecutionTrace();
     trace->set_tfdbg_context_id(strings::Printf("graph_%.2ld", i));
-    writer->WriteGraphExecutionTrace(trace);
+    TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
   }
   TF_ASSERT_OK(writer->FlushExecutionFiles());
 
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 53aa48b..a90fc2e 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -12059,7 +12059,7 @@
 //
 // value: The cropped area of the image must have an aspect ratio =
 // width / height within this range.
-// If not specified, defaults to {f:0.75  f:1.33}
+// If not specified, defaults to {f:0.75 f:1.33}
 func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
 	return func(m optionalAttr) {
 		m["aspect_ratio_range"] = value
@@ -12070,7 +12070,7 @@
 //
 // value: The cropped area of the image must contain a fraction of the
 // supplied image within this range.
-// If not specified, defaults to {f:0.05  f:1}
+// If not specified, defaults to {f:0.05 f:1}
 func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
 	return func(m optionalAttr) {
 		m["area_range"] = value
@@ -18975,7 +18975,7 @@
 //
 // value: The cropped area of the image must have an aspect ratio =
 // width / height within this range.
-// If not specified, defaults to {f:0.75  f:1.33}
+// If not specified, defaults to {f:0.75 f:1.33}
 func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr {
 	return func(m optionalAttr) {
 		m["aspect_ratio_range"] = value
@@ -18986,7 +18986,7 @@
 //
 // value: The cropped area of the image must contain a fraction of the
 // supplied image within this range.
-// If not specified, defaults to {f:0.05  f:1}
+// If not specified, defaults to {f:0.05 f:1}
 func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
 	return func(m optionalAttr) {
 		m["area_range"] = value
@@ -19390,7 +19390,7 @@
 // ImageSummaryBadColor sets the optional bad_color attribute to value.
 //
 // value: Color to use for pixels with non-finite values.
-// If not specified, defaults to {dtype:DT_UINT8  tensor_shape:{dim:{size:4}}  int_val:255  int_val:0  int_val:0  int_val:255}
+// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
 func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
 	return func(m optionalAttr) {
 		m["bad_color"] = value
@@ -20461,7 +20461,7 @@
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
 func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -21633,7 +21633,7 @@
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22341,7 +22341,7 @@
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func Conv2DDilations(value []int64) Conv2DAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22537,7 +22537,7 @@
 // QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value.
 //
 // value: List of dilation values.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22606,7 +22606,7 @@
 // QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value.
 //
 // value: List of dilation values.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22721,7 +22721,7 @@
 // QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value.
 //
 // value: List of dilation values.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22780,7 +22780,7 @@
 // QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value.
 //
 // value: List of dilation values.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -22954,7 +22954,7 @@
 // QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value.
 //
 // value: list of dilation values.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -23331,7 +23331,7 @@
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
 func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -25651,7 +25651,7 @@
 type Conv3DBackpropFilterAttr func(optionalAttr)
 
 // Conv3DBackpropFilterDilations sets the optional dilations attribute to value.
-// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
 func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -25714,7 +25714,7 @@
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
 func Conv3DDilations(value []int64) Conv3DAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -25965,7 +25965,7 @@
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -26449,7 +26449,7 @@
 // filter element on that dimension. The dimension order is determined by the
 // value of `data_format`, see above for details. Dilations in the batch and
 // depth dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -45537,7 +45537,7 @@
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -47477,7 +47477,7 @@
 type Conv3DBackpropInputAttr func(optionalAttr)
 
 // Conv3DBackpropInputDilations sets the optional dilations attribute to value.
-// If not specified, defaults to {i:1  i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
 func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -47548,7 +47548,7 @@
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
@@ -48537,7 +48537,7 @@
 // element on that dimension. The dimension order is determined by the value of
 // `data_format`, see above for details. Dilations in the batch and depth
 // dimensions must be 1.
-// If not specified, defaults to {i:1  i:1  i:1  i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
 func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr {
 	return func(m optionalAttr) {
 		m["dilations"] = value
diff --git a/tensorflow/go/saved_model.go b/tensorflow/go/saved_model.go
index 7aa1e83..64ae82e 100644
--- a/tensorflow/go/saved_model.go
+++ b/tensorflow/go/saved_model.go
@@ -22,7 +22,7 @@
 	"unsafe"
 
 	"github.com/golang/protobuf/proto"
-	corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"
+	corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
 )
 
 // #include <stdlib.h>
diff --git a/tensorflow/go/signature.go b/tensorflow/go/signature.go
index 8aac0e2..c2db0c7 100644
--- a/tensorflow/go/signature.go
+++ b/tensorflow/go/signature.go
@@ -16,7 +16,7 @@
 
 package tensorflow
 
-import corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"
+import corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
 
 // #include "tensorflow/c/c_api.h"
 import "C"
diff --git a/tensorflow/go/signature_test.go b/tensorflow/go/signature_test.go
index e6927f3..f9fa8427 100644
--- a/tensorflow/go/signature_test.go
+++ b/tensorflow/go/signature_test.go
@@ -20,9 +20,9 @@
 	"fmt"
 	"testing"
 
-	corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"
 	tspb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"
 	typb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"
+	corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
 )
 
 func TestSignatureFromProto(t *testing.T) {
diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c
index f70a600..e6b4789 100644
--- a/tensorflow/lite/c/common.c
+++ b/tensorflow/lite/c/common.c
@@ -79,7 +79,8 @@
 void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
 
 void TfLiteTensorDataFree(TfLiteTensor* t) {
-  if (t->allocation_type == kTfLiteDynamic) {
+  if (t->allocation_type == kTfLiteDynamic ||
+      t->allocation_type == kTfLitePersistentRo) {
     free(t->data.raw);
   }
   t->data.raw = NULL;
@@ -172,7 +173,8 @@
 }
 
 void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
-  if (tensor->allocation_type != kTfLiteDynamic) {
+  if (tensor->allocation_type != kTfLiteDynamic &&
+      tensor->allocation_type != kTfLitePersistentRo) {
     return;
   }
   // TODO(b/145340303): Tensor data should be aligned.
diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h
index 9657c7e..ab150e8 100644
--- a/tensorflow/lite/c/common.h
+++ b/tensorflow/lite/c/common.h
@@ -321,15 +321,23 @@
   void* data;
 } TfLitePtrUnion;
 
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+// Memory allocation strategies.
+//  * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated.
+//  * kTfLiteArenaRw: Arena allocated with no guarantees about persistence,
+//        and available during eval.
+//  * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and
+//        only available during eval.
+//  * kTfLiteDynamic: Allocated during eval, or for string tensors.
+//  * kTfLitePersistentRo: Allocated and populated during prepare. This is
+//        useful for tensors that can be computed during prepare and treated
+//        as constant inputs for downstream ops (also in prepare).
 typedef enum TfLiteAllocationType {
   kTfLiteMemNone = 0,
   kTfLiteMmapRo,
   kTfLiteArenaRw,
   kTfLiteArenaRwPersistent,
   kTfLiteDynamic,
+  kTfLitePersistentRo,
 } TfLiteAllocationType;
 
 // The delegates should use zero or positive integers to represent handles.
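
The new kTfLitePersistentRo type is handled as heap memory by TfLiteTensorDataFree/TfLiteTensorRealloc (common.c above) and is resizable in Subgraph::ResizeTensor (below). As an illustration only, a hypothetical op kernel (MyOpPrepare is an assumed name, not part of this change) could use it to materialize a prepare-time constant:

#include "tensorflow/lite/c/common.h"

// Sketch: compute a tensor once in Prepare and expose it as a constant input
// to downstream ops. Hypothetical kernel, for illustration only.
TfLiteStatus MyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  // Mark as prepare-time constant; ResizeTensor will heap-allocate it now
  // that kTfLitePersistentRo is treated like kTfLiteDynamic for realloc.
  output->allocation_type = kTfLitePersistentRo;
  TfLiteIntArray* shape = TfLiteIntArrayCreate(1);
  shape->data[0] = 4;
  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, shape));
  // Populate the data here; it is not rewritten during Invoke.
  for (int i = 0; i < 4; ++i) output->data.f[i] = 0.0f;
  return kTfLiteOk;
}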
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 4cebd05..7f4e0e2 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -1183,7 +1183,8 @@
   // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
   if (tensor->allocation_type == kTfLiteArenaRw ||
       tensor->allocation_type == kTfLiteDynamic ||
-      tensor->allocation_type == kTfLiteArenaRwPersistent) {
+      tensor->allocation_type == kTfLiteArenaRwPersistent ||
+      tensor->allocation_type == kTfLitePersistentRo) {
     tensor_resized_since_op_invoke_ |=
         TfLiteIntArrayEqual(tensor->dims, new_size) == 0;
     if (tensor->type != kTfLiteString) {
@@ -1195,14 +1196,16 @@
         return kTfLiteError;
       }
 
-      // Realloc space for kTfLiteDynamic tensors.
+      // Realloc space for heap-allocated tensors.
       TfLiteTensorRealloc(bytesRequired, tensor);
       tensor->bytes = bytesRequired;
     }
     if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
     tensor->dims = new_size;
 
-    if (tensor->allocation_type != kTfLiteDynamic) {
+    // Reset arena-allocated tensors; they will be allocated later.
+    if (tensor->allocation_type == kTfLiteArenaRw ||
+        tensor->allocation_type == kTfLiteArenaRwPersistent) {
       tensor->data.raw = nullptr;
     }
   } else {
diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD
index 9fe8060..d69d220 100644
--- a/tensorflow/lite/delegates/flex/BUILD
+++ b/tensorflow/lite/delegates/flex/BUILD
@@ -26,7 +26,7 @@
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib_lite",
+            "//tensorflow/core:portable_tensorflow_lib_lite",
         ],
         "//conditions:default": [
             "//tensorflow/c:c_api_internal",
@@ -66,7 +66,7 @@
             "//tensorflow/core:android_tensorflow_lib",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib",
+            "//tensorflow/core:portable_tensorflow_lib",
         ],
         "//conditions:default": [
             "//tensorflow/core:tensorflow",
@@ -103,7 +103,7 @@
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib_lite",
+            "//tensorflow/core:portable_tensorflow_lib_lite",
         ],
         "//conditions:default": [
             "//tensorflow/core:lib",
@@ -137,7 +137,7 @@
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib_lite",
+            "//tensorflow/core:portable_tensorflow_lib_lite",
         ],
         "//conditions:default": [
             "//tensorflow/core/common_runtime/eager:context",
@@ -183,7 +183,7 @@
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib_lite",
+            "//tensorflow/core:portable_tensorflow_lib_lite",
         ],
         "//conditions:default": [
             "//tensorflow/core/common_runtime/eager:context",
@@ -211,7 +211,7 @@
             "//tensorflow/core:android_tensorflow_lib",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib",
+            "//tensorflow/core:portable_tensorflow_lib",
         ],
         "//conditions:default": [
             "//tensorflow/core:tensorflow",
@@ -245,7 +245,7 @@
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib_lite",
+            "//tensorflow/core:portable_tensorflow_lib_lite",
         ],
         "//conditions:default": [
             "//tensorflow/c:c_api_internal",
diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 099f653..2581232 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -167,7 +167,7 @@
         "metal_delegate.h",
         "metal_delegate_internal.h",
     ],
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     deps = [":metal_delegate"],
 )
 
diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc
index 7ffb560..475eed4 100644
--- a/tensorflow/lite/delegates/gpu/cl/api.cc
+++ b/tensorflow/lite/delegates/gpu/cl/api.cc
@@ -352,10 +352,10 @@
 };
 
 TensorObject TensorToObj(const Tensor& tensor) {
-  if (tensor.StorageType() == TensorStorageType::BUFFER) {
+  if (tensor.GetStorageType() == TensorStorageType::BUFFER) {
     return OpenClBuffer{tensor.GetMemoryPtr()};
   }
-  if (tensor.StorageType() == TensorStorageType::IMAGE_BUFFER) {
+  if (tensor.GetStorageType() == TensorStorageType::IMAGE_BUFFER) {
     return OpenClBuffer{tensor.GetMemoryPtrForWriting()};
   }
   return OpenClTexture{tensor.GetMemoryPtr()};
@@ -516,9 +516,9 @@
   def.dimensions.h = tensor.Height();
   def.dimensions.w = tensor.Width();
   def.dimensions.c = tensor.Channels();
-  def.object_def.data_layout = ToDataLayout(tensor.StorageType());
-  def.object_def.data_type = tensor.DataType();
-  def.object_def.object_type = ToObjectType(tensor.StorageType());
+  def.object_def.data_layout = ToDataLayout(tensor.GetStorageType());
+  def.object_def.data_type = tensor.GetDataType();
+  def.object_def.object_type = ToObjectType(tensor.GetStorageType());
   def.object_def.user_provided = false;
   return def;
 }
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc
index f01975e..4a52508 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc
@@ -29,7 +29,7 @@
 namespace {
 
 absl::Status CreateImageBufferFromBuffer(const CLContext& context,
-                                         cl_mem memory, enum DataType data_type,
+                                         cl_mem memory, DataType data_type,
                                          int width, cl_mem* result) {
   cl_image_format format;
   cl_image_desc desc;
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h
index d59ef83..cb7d426 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.h
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.h
@@ -75,8 +75,8 @@
   int4 GetWHSB() const { return int4(shape_.w, shape_.h, Slices(), shape_.b); }
   int4 GetWHDS() const { return int4(shape_.w, shape_.h, shape_.d, Slices()); }
 
-  enum DataType DataType() const { return descriptor_.data_type; }
-  TensorStorageType StorageType() const { return descriptor_.storage_type; }
+  DataType GetDataType() const { return descriptor_.data_type; }
+  TensorStorageType GetStorageType() const { return descriptor_.storage_type; }
 
   // for profiling and memory statistics
   uint64_t GetMemorySizeInBytes() const;
diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc
index 3924f91..bdcf6f6 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.cc
+++ b/tensorflow/lite/delegates/gpu/common/operations.cc
@@ -506,6 +506,14 @@
               attr.appended.c + attr.prepended.c + input.c);
 }
 
+BHWDC CalculateOutputShape(const BHWDC& input, const Pad3DAttributes& attr) {
+  return BHWDC(attr.appended.b + attr.prepended.b + input.b,
+               attr.appended.h + attr.prepended.h + input.h,
+               attr.appended.w + attr.prepended.w + input.w,
+               attr.appended.d + attr.prepended.d + input.d,
+               attr.appended.c + attr.prepended.c + input.c);
+}
+
 BHWC CalculateOutputShape(const BHWC& input,
                           const FullyConnectedAttributes& attr) {
   return BHWC(input.b, 1, 1, attr.weights.shape.o);
diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h
index f8bfc77..d0268ee 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.h
+++ b/tensorflow/lite/delegates/gpu/common/operations.h
@@ -431,6 +431,17 @@
 // @return shape of a tensor after Pad operation is applied to the given input.
 BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr);
 
+struct Pad3DAttributes {
+  PaddingContentType type = PaddingContentType::ZEROS;
+
+  BHWDC prepended;
+  BHWDC appended;
+};
+
+// @return shape of a tensor after Pad3D operation is applied to the given
+// input.
+BHWDC CalculateOutputShape(const BHWDC& input, const Pad3DAttributes& attr);
+
 struct ConstTensorAttributes {
   Tensor<BHWC, DataType::FLOAT32> tensor;
 };
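
To make the new 3-D padding overload concrete, a minimal sketch (illustrative only; these types live in namespace tflite::gpu):

#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"

// Pads a BHWDC(1, 8, 8, 8, 4) tensor by one voxel on each spatial side.
tflite::gpu::BHWDC PaddedShapeExample() {
  tflite::gpu::BHWDC input(1, 8, 8, 8, 4);
  tflite::gpu::Pad3DAttributes attr;
  attr.prepended = tflite::gpu::BHWDC(0, 1, 1, 1, 0);
  attr.appended = tflite::gpu::BHWDC(0, 1, 1, 1, 0);
  // Every dimension grows by prepended + appended: BHWDC(1, 10, 10, 10, 4).
  return CalculateOutputShape(input, attr);
}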
diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc
index 58da886..4b6727e 100644
--- a/tensorflow/lite/delegates/gpu/delegate.cc
+++ b/tensorflow/lite/delegates/gpu/delegate.cc
@@ -263,12 +263,12 @@
 
     input_refs->clear();
     output_refs->clear();
-    const auto& inputs = graph->inputs();
+    const auto inputs = graph->inputs();
     input_refs->reserve(inputs.size());
     for (const auto& input : inputs) {
       input_refs->push_back(input->tensor.ref);
     }
-    const auto& outputs = graph->outputs();
+    const auto outputs = graph->outputs();
     output_refs->reserve(outputs.size());
     for (const auto& output : outputs) {
       output_refs->push_back(output->tensor.ref);
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index 192c787..4db8f3d 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -80,7 +80,7 @@
 ios_unit_test(
     name = "common_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -117,7 +117,7 @@
 ios_unit_test(
     name = "compiled_model_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -205,7 +205,7 @@
 ios_unit_test(
     name = "inference_context_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -235,7 +235,7 @@
         "iphone",
     ],
     infoplists = ["Info.plist"],
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     provisioning_profile = "//tensorflow/lite/delegates/gpu/metal:provisioning_profile.mobileprovision",
     tags = tf_gpu_tests_tags() + [
         "local",
@@ -267,7 +267,7 @@
 
 ios_unit_test(
     name = "ComponentsTests",
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + ["notap"],
     test_host = ":TestApplication",
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index a1052b8..657e9b5 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -71,7 +71,7 @@
 ios_unit_test(
     name = "add_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -109,7 +109,7 @@
 ios_unit_test(
     name = "concat_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -151,7 +151,7 @@
 ios_unit_test(
     name = "conv_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -213,7 +213,7 @@
 ios_unit_test(
     name = "depthwise_conv_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -253,7 +253,7 @@
 ios_unit_test(
     name = "elementwise_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -293,7 +293,7 @@
 ios_unit_test(
     name = "fully_connected_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -332,7 +332,7 @@
 ios_unit_test(
     name = "max_unpooling_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -371,7 +371,7 @@
 ios_unit_test(
     name = "mean_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = [
         "notap",
@@ -450,7 +450,7 @@
 ios_unit_test(
     name = "padding_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -490,7 +490,7 @@
 ios_unit_test(
     name = "pooling_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -530,7 +530,7 @@
 ios_unit_test(
     name = "prelu_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -569,7 +569,7 @@
 ios_unit_test(
     name = "relu_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -608,7 +608,7 @@
 ios_unit_test(
     name = "resize_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -648,7 +648,7 @@
 ios_unit_test(
     name = "reshape_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -687,7 +687,7 @@
 ios_unit_test(
     name = "slice_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -727,7 +727,7 @@
 ios_unit_test(
     name = "softmax_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -764,7 +764,7 @@
 ios_unit_test(
     name = "space_to_depth_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -804,7 +804,7 @@
 ios_unit_test(
     name = "transpose_conv_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
@@ -885,7 +885,7 @@
 ios_unit_test(
     name = "winograd_test",
     testonly = 1,
-    minimum_os_version = "10.0",
+    minimum_os_version = "11.0",
     runner = tflite_ios_lab_runner("IOS_LATEST"),
     tags = tf_gpu_tests_tags() + [
         "notap",
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index e790d42..002c299 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -3769,7 +3769,8 @@
         }
       } else if (reg->builtin_code == kTfLiteBuiltinMaximum ||
                  reg->builtin_code == kTfLiteBuiltinMinimum) {
-        const TfLiteTensor& operand_tensor = context->tensors[input_pos];
+        const TfLiteTensor& operand_tensor =
+            context->tensors[node->inputs->data[input_pos]];
         if (operand_tensor.dims->size == 0) {
           int tensor_index;
 
@@ -3814,7 +3815,8 @@
                   reg->builtin_code == kTfLiteBuiltinSum) &&
                  (input_pos == 1)) {
 // The axis needs to be converted to a tensor if specified as a scalar
-        const TfLiteTensor& axis_tensor = context->tensors[1];
+        const TfLiteTensor& axis_tensor =
+            context->tensors[node->inputs->data[input_pos]];
         if (axis_tensor.dims->size == 0) {
           TF_LITE_ENSURE_STATUS(
               builder.AddVectorInt32Operand(axis_tensor.data.i32, 1));
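
The two hunks above fix the same indexing bug: input_pos walks the node's input list, while TfLiteContext::tensors is indexed by global tensor id, so context->tensors[input_pos] (or the hardcoded context->tensors[1]) read an unrelated tensor whenever a node's inputs are not tensors 0, 1, 2, ... in order. As a sketch (OperandAt is a hypothetical helper, not part of this change):

#include "tensorflow/lite/c/common.h"

// Resolves the input_pos-th input of `node` to the tensor it actually names.
const TfLiteTensor& OperandAt(TfLiteContext* context, const TfLiteNode* node,
                              int input_pos) {
  const int tensor_id = node->inputs->data[input_pos];  // global tensor id
  return context->tensors[tensor_id];
}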
diff --git a/tensorflow/lite/delegates/utils/BUILD b/tensorflow/lite/delegates/utils/BUILD
new file mode 100644
index 0000000..069da16
--- /dev/null
+++ b/tensorflow/lite/delegates/utils/BUILD
@@ -0,0 +1,36 @@
+package(
+    default_visibility = [
+        "//visibility:public",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "simple_delegate",
+    srcs = [
+        "simple_delegate.cc",
+    ],
+    hdrs = [
+        "simple_delegate.h",
+    ],
+    deps = [
+        "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite:minimal_logging",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/delegates:utils",
+        "//tensorflow/lite/kernels/internal:compatibility",
+    ],
+)
+
+cc_test(
+    name = "simple_delegate_test",
+    srcs = ["simple_delegate_test.cc"],
+    deps = [
+        ":simple_delegate",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:builtin_ops",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
diff --git a/tensorflow/lite/delegates/utils/simple_delegate.cc b/tensorflow/lite/delegates/utils/simple_delegate.cc
new file mode 100644
index 0000000..51736e5
--- /dev/null
+++ b/tensorflow/lite/delegates/utils/simple_delegate.cc
@@ -0,0 +1,140 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/utils/simple_delegate.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/context_util.h"
+#include "tensorflow/lite/delegates/utils.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/minimal_logging.h"
+
+namespace tflite {
+namespace {
+TfLiteRegistration GetDelegateKernelRegistration(
+    SimpleDelegateInterface* delegate) {
+  TfLiteRegistration kernel_registration;
+  kernel_registration.profiling_string = nullptr;
+  kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
+  kernel_registration.custom_name = delegate->name();
+  kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
+    delete reinterpret_cast<SimpleDelegateKernelInterface*>(buffer);
+  };
+  kernel_registration.init = [](TfLiteContext* context, const char* buffer,
+                                size_t length) -> void* {
+    const TfLiteDelegateParams* params =
+        reinterpret_cast<const TfLiteDelegateParams*>(buffer);
+    if (params == nullptr) {
+      TF_LITE_KERNEL_LOG(context, "NULL TfLiteDelegateParams passed.");
+      return nullptr;
+    }
+    auto* delegate =
+        reinterpret_cast<SimpleDelegateInterface*>(params->delegate->data_);
+    std::unique_ptr<SimpleDelegateKernelInterface> delegate_kernel(
+        delegate->CreateDelegateKernelInterface());
+    if (delegate_kernel->Init(context, params) != kTfLiteOk) {
+      return nullptr;
+    }
+    return delegate_kernel.release();
+  };
+  kernel_registration.prepare = [](TfLiteContext* context,
+                                   TfLiteNode* node) -> TfLiteStatus {
+    if (node->user_data == nullptr) {
+      TF_LITE_KERNEL_LOG(context, "Delegate kernel was not initialized");
+      return kTfLiteError;
+    }
+    SimpleDelegateKernelInterface* delegate_kernel =
+        reinterpret_cast<SimpleDelegateKernelInterface*>(node->user_data);
+    return delegate_kernel->Prepare(context, node);
+  };
+  kernel_registration.invoke = [](TfLiteContext* context,
+                                  TfLiteNode* node) -> TfLiteStatus {
+    SimpleDelegateKernelInterface* delegate_kernel =
+        reinterpret_cast<SimpleDelegateKernelInterface*>(node->user_data);
+    TFLITE_DCHECK(delegate_kernel != nullptr);
+    return delegate_kernel->Invoke(context, node);
+  };
+
+  return kernel_registration;
+}
+
+TfLiteStatus DelegatePrepare(TfLiteContext* context,
+                             TfLiteDelegate* base_delegate) {
+  auto* delegate =
+      reinterpret_cast<SimpleDelegateInterface*>(base_delegate->data_);
+  delegates::IsNodeSupportedFn node_supported_fn =
+      [=](TfLiteContext* context, TfLiteNode* node,
+          TfLiteRegistration* registration,
+          std::string* unsupported_details) -> bool {
+    return delegate->IsNodeSupportedByDelegate(registration, node, context);
+  };
+  // TODO(b/149484598): Update to have method that gets all supported nodes.
+  delegates::GraphPartitionHelper helper(context, node_supported_fn);
+  TF_LITE_ENSURE_STATUS(helper.Partition(nullptr));
+
+  const auto delegate_partitions = helper.GetFirstNLargestPartitions();
+
+  // To avoid creating a new TfLiteIntArray and freeing it later, we reserve
+  // one element to represent TfLiteIntArray.size, which is the first element
+  // of the TfLiteIntArray C struct.
+  std::vector<int> supported_nodes(1);
+  for (const auto partition : delegate_partitions) {
+    auto* nodes = partition->nodes_to_replace;
+    supported_nodes.insert(supported_nodes.end(), nodes->data,
+                           nodes->data + nodes->size);
+  }
+  // Set first element to the number of nodes to replace.
+  supported_nodes[0] = supported_nodes.size() - 1;
+
+  TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
+                  "%s delegate: %d nodes delegated out of %d nodes with "
+                  "%d partitions.\n",
+                  delegate->name(), supported_nodes[0],
+                  helper.num_total_nodes(), delegate_partitions.size());
+  TfLiteRegistration delegate_kernel_registration =
+      GetDelegateKernelRegistration(delegate);
+
+  return context->ReplaceNodeSubsetsWithDelegateKernels(
+      context, delegate_kernel_registration,
+      reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()), base_delegate);
+}
+}  // namespace
+
+TfLiteDelegate* TfLiteDelegateFactory::CreateSimpleDelegate(
+    std::unique_ptr<SimpleDelegateInterface> simple_delegate) {
+  if (simple_delegate == nullptr) {
+    return nullptr;
+  }
+  auto delegate = new TfLiteDelegate();
+  delegate->Prepare = &DelegatePrepare;
+  delegate->flags = kTfLiteDelegateFlagsNone;
+  delegate->CopyFromBufferHandle = nullptr;
+  delegate->CopyToBufferHandle = nullptr;
+  delegate->FreeBufferHandle = nullptr;
+  delegate->data_ = simple_delegate.release();
+  return delegate;
+}
+
+void TfLiteDelegateFactory::DeleteSimpleDelegate(TfLiteDelegate* delegate) {
+  if (!delegate) return;
+  SimpleDelegateInterface* simple_delegate =
+      reinterpret_cast<SimpleDelegateInterface*>(delegate->data_);
+  delete simple_delegate;
+  delete delegate;
+}
+
+}  // namespace tflite
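
A note on the supported_nodes vector built in DelegatePrepare above: the trick works because TfLiteIntArray begins with `int size` followed by its `int data[]`, so a vector whose element 0 holds the count can stand in for the struct without a TfLiteIntArrayCreate/TfLiteIntArrayFree pair. A minimal sketch of the layout:

#include <vector>
#include "tensorflow/lite/c/common.h"

int IntArrayLayoutSketch() {
  std::vector<int> supported_nodes = {3, 0, 1, 2};  // size, then node ids.
  auto* as_array = reinterpret_cast<TfLiteIntArray*>(supported_nodes.data());
  return as_array->size;  // == 3; as_array->data[0..2] == {0, 1, 2}.
}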
diff --git a/tensorflow/lite/delegates/utils/simple_delegate.h b/tensorflow/lite/delegates/utils/simple_delegate.h
new file mode 100644
index 0000000..bf35fbc
--- /dev/null
+++ b/tensorflow/lite/delegates/utils/simple_delegate.h
@@ -0,0 +1,109 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file has utilities that facilitate creating new delegates.
+// - SimpleDelegateKernelInterface: Represents a kernel which handles a subgraph
+// to be delegated. It has Init/Prepare/Invoke which will be called during
+// inference, similar to TFLite kernels. The delegate owner should implement
+// this interface to build/prepare/invoke the delegated subgraph.
+// - SimpleDelegateInterface:
+// This class wraps a TfLiteDelegate; users implement the interface and then
+// call TfLiteDelegateFactory::CreateSimpleDelegate() to get a TfLiteDelegate*
+// that can be passed to ModifyGraphWithDelegate.
+#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
+#define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
+
+#include <memory>
+
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+// Users should inherit from this class and implement the interface below.
+// Each instance represents a single part of the graph (subgraph).
+class SimpleDelegateKernelInterface {
+ public:
+  virtual ~SimpleDelegateKernelInterface() {}
+
+  // Initializes a delegated subgraph.
+  // The nodes in the subgraph are inside TfLiteDelegateParams->nodes_to_replace.
+  virtual TfLiteStatus Init(TfLiteContext* context,
+                            const TfLiteDelegateParams* params) = 0;
+
+  // Will be called by the framework. Should handle any needed preparation
+  // for the subgraph, e.g. allocating buffers or compiling the model.
+  // Returns a status, signaling any errors.
+  virtual TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) = 0;
+
+  // Actual subgraph inference should happen on this call.
+  // Returns a status, signaling any errors.
+  virtual TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) = 0;
+};
+
+// Pure interface that clients should implement.
+// The interface represents a delegate's capabilities and provides a factory
+// for SimpleDelegateKernelInterface.
+//
+// Clients should implement the following methods:
+// - IsNodeSupportedByDelegate
+// - name
+// - CreateDelegateKernelInterface
+class SimpleDelegateInterface {
+ public:
+  SimpleDelegateInterface() {}
+
+  virtual ~SimpleDelegateInterface() {}
+
+  // Returns true if 'node' is supported by the delegate. False otherwise.
+  virtual bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
+                                         const TfLiteNode* node,
+                                         TfLiteContext* context) const = 0;
+
+  // Returns a name that identifies the delegate.
+  // This name is used for debugging/logging/profiling.
+  virtual const char* name() const = 0;
+
+  // Returns an instance of an object that implements the
+  // SimpleDelegateKernelInterface.
+  // Each instance of SimpleDelegateKernelInterface represents one subgraph
+  // to be delegated.
+  // The caller takes ownership of the returned object.
+  virtual std::unique_ptr<SimpleDelegateKernelInterface>
+  CreateDelegateKernelInterface() = 0;
+};
+
+// Factory class that provides two static methods:
+// - CreateSimpleDelegate: constructs a TfLiteDelegate from the provided
+//   SimpleDelegateInterface.
+// - DeleteSimpleDelegate: deletes a TfLiteDelegate (and the underlying
+//   SimpleDelegateInterface), given a TfLiteDelegate* created by
+//   'CreateSimpleDelegate'.
+// Users should use these methods to create and destroy the delegate.
+class TfLiteDelegateFactory {
+ public:
+  // Creates TfLiteDelegate from the provided SimpleDelegateInterface.
+  // The returned TfLiteDelegate should be deleted using DeleteSimpleDelegate.
+  static TfLiteDelegate* CreateSimpleDelegate(
+      std::unique_ptr<SimpleDelegateInterface> simple_delegate);
+
+  // Deletes 'delegate'; the passed pointer must be the one returned
+  // from CreateSimpleDelegate.
+  // This function destroys the underlying SimpleDelegateInterface object too.
+  static void DeleteSimpleDelegate(TfLiteDelegate* delegate);
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
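
End to end, a client of this API implements SimpleDelegateInterface and wires it up roughly as below (MyDelegate is a hypothetical name; see TestSimpleDelegate in the test that follows for a concrete implementation):

#include <memory>

#include "tensorflow/lite/delegates/utils/simple_delegate.h"
#include "tensorflow/lite/interpreter.h"

// Sketch only: MyDelegate implements tflite::SimpleDelegateInterface.
void ApplyMyDelegate(tflite::Interpreter* interpreter) {
  TfLiteDelegate* delegate =
      tflite::TfLiteDelegateFactory::CreateSimpleDelegate(
          std::make_unique<MyDelegate>());
  interpreter->ModifyGraphWithDelegate(delegate);
  // ... run inference via the interpreter ...
  tflite::TfLiteDelegateFactory::DeleteSimpleDelegate(delegate);
}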
diff --git a/tensorflow/lite/delegates/utils/simple_delegate_test.cc b/tensorflow/lite/delegates/utils/simple_delegate_test.cc
new file mode 100644
index 0000000..fa6d528
--- /dev/null
+++ b/tensorflow/lite/delegates/utils/simple_delegate_test.cc
@@ -0,0 +1,194 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/utils/simple_delegate.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/builtin_op_kernels.h"
+
+namespace tflite {
+namespace {
+// Delegate options.
+struct TestSimpleDelegateOptions {
+  // Allowed ops to delegate.
+  int allowed_builtin_code;
+  // Report error during init.
+  bool error_during_init = false;
+  // Report error during prepare.
+  bool error_during_prepare = false;
+  // Report error during invoke.
+  bool error_during_invoke = false;
+};
+
+// Dummy delegate kernel.
+class TestSimpleDelegateKernel : public SimpleDelegateKernelInterface {
+ public:
+  explicit TestSimpleDelegateKernel(TestSimpleDelegateOptions options)
+      : options_(options) {}
+
+  TfLiteStatus Init(TfLiteContext* context,
+                    const TfLiteDelegateParams* params) override {
+    return !options_.error_during_init ? kTfLiteOk : kTfLiteError;
+  }
+
+  TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) override {
+    return !options_.error_during_prepare ? kTfLiteOk : kTfLiteError;
+  }
+
+  TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) override {
+    return !options_.error_during_invoke ? kTfLiteOk : kTfLiteError;
+  }
+
+ private:
+  TestSimpleDelegateOptions options_;
+};
+
+// Simple delegate that implements SimpleDelegateInterface.
+// This holds the delegate's capabilities.
+class TestSimpleDelegate : public SimpleDelegateInterface {
+ public:
+  explicit TestSimpleDelegate(TestSimpleDelegateOptions options)
+      : options_(options) {}
+  bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
+                                 const TfLiteNode* node,
+                                 TfLiteContext* context) const override {
+    return options_.allowed_builtin_code == registration->builtin_code;
+  }
+
+  const char* name() const override { return "TestSimpleDelegate"; }
+
+  std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
+      override {
+    return std::make_unique<TestSimpleDelegateKernel>(options_);
+  }
+
+ private:
+  TestSimpleDelegateOptions options_;
+};
+
+class TestDelegate : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    interpreter_.reset(new Interpreter);
+    interpreter_->AddTensors(5);
+    interpreter_->SetInputs({0, 1});
+    interpreter_->SetOutputs({3, 4});
+    TfLiteQuantizationParams quant;
+    interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+                                               quant);
+    interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+                                               quant);
+    interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
+                                               quant);
+    interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
+                                               quant);
+    interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
+                                               quant);
+    TfLiteRegistration* reg = ops::builtin::Register_ADD();
+    void* builtin_data_1 = malloc(sizeof(int));
+    void* builtin_data_2 = malloc(sizeof(int));
+    void* builtin_data_3 = malloc(sizeof(int));
+    interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, builtin_data_1,
+                                        reg);
+    interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, builtin_data_2,
+                                        reg);
+    interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, builtin_data_3,
+                                        reg);
+  }
+
+  void TearDown() override {
+    interpreter_.reset();
+    TfLiteDelegateFactory::DeleteSimpleDelegate(delegate_);
+  }
+
+ protected:
+  std::unique_ptr<Interpreter> interpreter_;
+  TfLiteDelegate* delegate_ = nullptr;
+};
+
+TEST_F(TestDelegate, BasicDelegate) {
+  TestSimpleDelegateOptions options;
+  options.allowed_builtin_code = kTfLiteBuiltinAdd;
+  delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+      std::make_unique<TestSimpleDelegate>(options));
+  interpreter_->ModifyGraphWithDelegate(delegate_);
+
+  ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+  int node = interpreter_->execution_plan()[0];
+  const auto* node_and_reg = interpreter_->node_and_registration(node);
+  EXPECT_EQ("TestSimpleDelegate", node_and_reg->second.custom_name);
+
+  const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>(
+      node_and_reg->first.builtin_data);
+  ASSERT_EQ(params->nodes_to_replace->size, 3);
+  EXPECT_EQ(params->nodes_to_replace->data[0], 0);
+  EXPECT_EQ(params->nodes_to_replace->data[1], 1);
+  EXPECT_EQ(params->nodes_to_replace->data[2], 2);
+
+  ASSERT_EQ(params->input_tensors->size, 2);
+  EXPECT_EQ(params->input_tensors->data[0], 0);
+  EXPECT_EQ(params->input_tensors->data[1], 1);
+
+  ASSERT_EQ(params->output_tensors->size, 2);
+  EXPECT_EQ(params->output_tensors->data[0], 3);
+  EXPECT_EQ(params->output_tensors->data[1], 4);
+}
+
+TEST_F(TestDelegate, NoNodesToDelegate) {
+  TestSimpleDelegateOptions options;
+  options.allowed_builtin_code = kTfLiteBuiltinSub;
+  delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+      std::make_unique<TestSimpleDelegate>(options));
+  interpreter_->ModifyGraphWithDelegate(delegate_);
+
+  ASSERT_EQ(interpreter_->execution_plan().size(), 3);
+}
+
+TEST_F(TestDelegate, DelegateFailedPrepare) {
+  TestSimpleDelegateOptions options;
+  options.allowed_builtin_code = kTfLiteBuiltinAdd;
+  options.error_during_prepare = true;
+  delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+      std::make_unique<TestSimpleDelegate>(options));
+  ASSERT_EQ(kTfLiteDelegateError,
+            interpreter_->ModifyGraphWithDelegate(delegate_));
+}
+
+TEST_F(TestDelegate, DelegateFailedInvoke) {
+  TestSimpleDelegateOptions options;
+  options.allowed_builtin_code = kTfLiteBuiltinAdd;
+  options.error_during_invoke = true;
+  delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+      std::make_unique<TestSimpleDelegate>(options));
+  ASSERT_EQ(kTfLiteOk, interpreter_->ModifyGraphWithDelegate(delegate_));
+  ASSERT_EQ(kTfLiteError, interpreter_->Invoke());
+}
+
+TEST_F(TestDelegate, DelegateFailedInit) {
+  TestSimpleDelegateOptions options;
+  options.allowed_builtin_code = kTfLiteBuiltinAdd;
+  options.error_during_init = true;
+  delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+      std::make_unique<TestSimpleDelegate>(options));
+  ASSERT_EQ(kTfLiteDelegateError,
+            interpreter_->ModifyGraphWithDelegate(delegate_));
+}
+}  // namespace
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/delegates/hexagon/README.md b/tensorflow/lite/experimental/delegates/hexagon/README.md
index 5cf71fd..07f1a92 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/README.md
+++ b/tensorflow/lite/experimental/delegates/hexagon/README.md
@@ -80,6 +80,7 @@
 * L2Normalization (without any activation)
 * Logistic (aka Sigmoid)
 * MaxPool2D (without any activation) (b/129276536)
+* MirrorPad
 * Mul (without any activation) (b/129276536)
 * Neg
 * Pad: Only supports 0 padding (b/139277813)
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD
index ae8ffe2..ff76498 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD
@@ -19,6 +19,7 @@
         "hardswish_builder.cc",
         "l2_normalization_builder.cc",
         "matmul_builder.cc",
+        "mirror_pad_builder.cc",
         "neg_op_builder.cc",
         "op_builder.cc",
         "pad_builder.cc",
@@ -45,6 +46,7 @@
         "hardswish_builder.h",
         "l2_normalization_builder.h",
         "matmul_builder.h",
+        "mirror_pad_builder.h",
         "neg_op_builder.h",
         "op_builder.h",
         "pad_builder.h",
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.cc
new file mode 100644
index 0000000..2a04088
--- /dev/null
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.cc
@@ -0,0 +1,112 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.h"
+
+#include <stdint.h>
+
+#include <limits>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace delegates {
+namespace hexagon {
+TfLiteStatus MirrorPadOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
+                                                  const TfLiteIntArray* outputs,
+                                                  TfLiteContext* context) {
+  static int quant_bound_shape[] = {1, 1, 1, 1};
+  int tensor_id;
+
+  // Input data tensor.
+  tensor_id = inputs->data[0];
+  const auto& input_tensor = context->tensors[tensor_id];
+  AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
+
+  // Padding tensor.
+  // Should be a constant.
+  tensor_id = inputs->data[1];
+  const auto& padding_tensor = context->tensors[tensor_id];
+  if (padding_tensor.dims->size != 2 || padding_tensor.dims->data[0] > 4 ||
+      padding_tensor.dims->data[1] != 2) {
+    TF_LITE_KERNEL_LOG(context, "Invalid padding tensor shape");
+    return kTfLiteError;
+  }
+  paddings_shape_ = {1, 1, 4, 2};
+  std::vector<int> padding_data(8, 0);
+  // Hexagon always expects padding data for each dimension in order {b, h, w,
+  // d}. This start value ensures we pad the non-relevant dimensions with 0.
+  int padding_data_start = 8 - padding_tensor.dims->data[0] * 2;
+  for (int i = 0; i < padding_tensor.dims->data[0] * 2; ++i) {
+    padding_data[padding_data_start + i] = padding_tensor.data.i32[i];
+  }
+  auto* const_padding_node = graph_builder_->AddConstNodeWithData(
+      paddings_shape_.data(), reinterpret_cast<char*>(padding_data.data()),
+      padding_data.size() * sizeof(padding_data[0]));
+  AddInput(TensorID(const_padding_node->GetID(), 0));
+  // Padding type.
+  const TfLiteMirrorPaddingParams* params =
+      reinterpret_cast<const TfLiteMirrorPaddingParams*>(builtin_data_);
+  if (params->mode == kTfLiteMirrorPaddingReflect) {
+    SetPaddingType(NN_PAD_MIRROR_REFLECT);
+  } else if (params->mode == kTfLiteMirrorPaddingSymmetric) {
+    SetPaddingType(NN_PAD_MIRROR_SYMMETRIC);
+  }
+
+  // Min/max values for input tensor.
+  TF_LITE_ENSURE_STATUS(
+      ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_));
+  auto* input_min_const = graph_builder_->AddConstNodeWithData(
+      quant_bound_shape, reinterpret_cast<char*>(&input_min_),
+      sizeof(input_min_));
+  auto* input_max_const = graph_builder_->AddConstNodeWithData(
+      quant_bound_shape, reinterpret_cast<char*>(&input_max_),
+      sizeof(input_max_));
+  AddInput(TensorID(input_min_const->GetID(), 0));
+  AddInput(TensorID(input_max_const->GetID(), 0));
+
+  // Hexagon outputs for this node.
+  int output_batch_size, output_height_size, output_width_size,
+      output_depth_size;
+  GetDims(&output_batch_size, &output_height_size, &output_width_size,
+          &output_depth_size, context->tensors[outputs->data[0]].dims);
+  node_output_ = AddOutput(sizeof(uint8_t), 4,
+                           {output_batch_size, output_height_size,
+                            output_width_size, output_depth_size});
+  AddOutput(sizeof(float), 4, {1, 1, 1, 1});
+  AddOutput(sizeof(float), 4, {1, 1, 1, 1});
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus MirrorPadOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
+                                                 TfLiteContext* context) {
+  // Should be only 1 output.
+  graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first,
+                                  node_output_.second);
+  return kTfLiteOk;
+}
+
+MirrorPadOpBuilder::~MirrorPadOpBuilder() {}
+
+OpBuilder* CreateMirrorPadBuilder(GraphBuilder* graph_builder, int op_type) {
+  return new MirrorPadOpBuilder(graph_builder, op_type);
+}
+
+}  // namespace hexagon
+}  // namespace delegates
+}  // namespace tflite
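
A worked example of the padding layout computed in PopulateSubGraph above (illustrative only): for a 2-D TFLite input the padding tensor has shape {2, 2}, so padding_data_start = 8 - 2 * 2 = 4 and the TFLite padding pairs land in the trailing {w, d} slots, matching how GetDims maps a 2-D shape onto {b=1, h=1, w, d}:

#include <vector>

void MirrorPadLayoutSketch() {
  std::vector<int> padding_data(8, 0);    // Order: {b, b, h, h, w, w, d, d}.
  const int tflite_rank = 2;              // padding_tensor.dims->data[0]
  const int tflite_paddings[] = {1, 2, 3, 4};
  const int start = 8 - tflite_rank * 2;  // == 4: b and h stay padded by 0.
  for (int i = 0; i < tflite_rank * 2; ++i) {
    padding_data[start + i] = tflite_paddings[i];
  }
  // padding_data == {0, 0, 0, 0, 1, 2, 3, 4}: w gets {1, 2}, d gets {3, 4}.
}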
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.h
new file mode 100644
index 0000000..6fcb260
--- /dev/null
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.h
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_MIRROR_PAD_BUILDER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_MIRROR_PAD_BUILDER_H_
+
+#include <vector>
+
+#include "tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h"
+
+namespace tflite {
+namespace delegates {
+namespace hexagon {
+
+class MirrorPadOpBuilder : public OpBuilder {
+ public:
+  explicit MirrorPadOpBuilder(GraphBuilder* graph_builder, int op_type)
+      : OpBuilder(graph_builder, op_type) {}
+  TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
+                                const TfLiteIntArray* outputs,
+                                TfLiteContext* context) override;
+
+  TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
+                               TfLiteContext* context) override;
+
+  ~MirrorPadOpBuilder() override;
+
+ private:
+  TensorID node_output_;
+  float input_min_, input_max_;
+  std::vector<int> paddings_shape_;
+};
+
+}  // namespace hexagon
+}  // namespace delegates
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_MIRROR_PAD_BUILDER_H_
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc
index e20127a..c7432e6 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc
@@ -43,6 +43,8 @@
       return CreateReduceBuilder(this, OP_QuantizedSum_8to32);
     case kTfLiteBuiltinPad:
       return CreatePadBuilder(this, OP_QuantizedPad_8);
+    case kTfLiteBuiltinMirrorPad:
+      return CreateMirrorPadBuilder(this, OP_MirrorPad_8);
     case kTfLiteBuiltinFullyConnected:
       return CreateMatMulBuilder(this, OP_QuantizedMatMul_8x8to32);
     case kTfLiteBuiltinAveragePool2d:
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h b/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h
index e7236fb..0beb88c 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h
@@ -35,6 +35,7 @@
 OpBuilder* CreateReshapeBuilder(GraphBuilder* graph_builder, int op_type);
 OpBuilder* CreateSoftmaxBuilder(GraphBuilder* graph_builder, int op_type);
 OpBuilder* CreateReduceBuilder(GraphBuilder* graph_builder, int op_type);
+OpBuilder* CreateMirrorPadBuilder(GraphBuilder* graph_builder, int op_type);
 OpBuilder* CreatePadBuilder(GraphBuilder* graph_builder, int op_type);
 OpBuilder* CreateResizeNearestNeighborBuilder(GraphBuilder* graph_builder,
                                               int op_type);
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD
index b1df59c..47a78dc 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD
@@ -30,6 +30,7 @@
         "conv_test.cc",
         "l2_norm_test.cc",
         "matmul_test.cc",
+        "mirror_pad_test.cc",
         "mul_test.cc",
         "neg_test.cc",
         "pad_test.cc",
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mirror_pad_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mirror_pad_test.cc
new file mode 100644
index 0000000..4caf96a
--- /dev/null
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mirror_pad_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h"
+
+namespace tflite {
+using testing::ElementsAreArray;
+
+template <typename T>
+class MirrorPadOpModel : public SingleOpModelWithHexagon {
+ public:
+  MirrorPadOpModel(const TensorData& input,
+                   std::initializer_list<int> paddings_shape,
+                   std::initializer_list<int> paddings,
+                   const TensorData& output, const tflite::MirrorPadMode mode) {
+    input_id_ = AddInput(input);
+    padding_matrix_id_ =
+        AddConstInput(TensorType_INT32, paddings, paddings_shape);
+    output_id_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_MIRROR_PAD, BuiltinOptions_MirrorPadOptions,
+                 CreateMirrorPadOptions(builder_, mode).Union());
+    BuildInterpreter({GetShape(input_id_), GetShape(padding_matrix_id_)});
+  }
+
+  int input_tensor_id() { return input_id_; }
+
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
+
+ protected:
+  int input_id_;
+  int padding_matrix_id_;
+  int output_id_;
+};
+
+TEST(MirrorPadTest, EmptyPad_UInt8) {
+  MirrorPadOpModel<uint8_t> model(
+      {TensorType_UINT8, {2, 3}, -1.0, 1.0}, {2, 2}, {0, 0, 0, 0},
+      {TensorType_UINT8, {}, -1.0, 1.0}, tflite::MirrorPadMode_REFLECT);
+  model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+  model.ApplyDelegateAndInvoke();
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(MirrorPadTest, PadBothSides_Symmetric_Int8) {
+  MirrorPadOpModel<int8_t> model({TensorType_INT8, {2, 3}, -1.0, 1.0}, {2, 2},
+                                 {1, 1, 1, 1}, {TensorType_INT8, {}, -1.0, 1.0},
+                                 tflite::MirrorPadMode_SYMMETRIC);
+  model.PopulateTensor<int8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+  model.ApplyDelegateAndInvoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({1, 1, 2, 3, 3, 1, 1, 2, 3, 3,
+                                4, 4, 5, 6, 6, 4, 4, 5, 6, 6}));
+}
+
+TEST(MirrorPadTest, PadBothSides_Reflect_UInt8) {
+  MirrorPadOpModel<uint8_t> model(
+      {TensorType_UINT8, {2, 3}, -1.0, 1.0}, {2, 2}, {1, 1, 1, 1},
+      {TensorType_UINT8, {}, -1.0, 1.0}, tflite::MirrorPadMode_REFLECT);
+  model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+  model.ApplyDelegateAndInvoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({5, 4, 5, 6, 5, 2, 1, 2, 3, 2,
+                                5, 4, 5, 6, 5, 2, 1, 2, 3, 2}));
+}
+
+TEST(MirrorPadTest, PadOneSide_left_Reflect_Int8) {
+  MirrorPadOpModel<int8_t> model({TensorType_INT8, {2, 3}, -1.0, 1.0}, {2, 2},
+                                 {1, 0, 1, 0}, {TensorType_INT8, {}, -1.0, 1.0},
+                                 tflite::MirrorPadMode_REFLECT);
+  model.PopulateTensor<int8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+  model.ApplyDelegateAndInvoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({5, 4, 5, 6, 2, 1, 2, 3, 5, 4, 5, 6}));
+}
+
+TEST(MirrorPadTest, PadOneSide_right_Symmetric_UInt8) {
+  MirrorPadOpModel<uint8_t> model(
+      {TensorType_UINT8, {2, 3}, -1.0, 1.0}, {2, 2}, {0, 1, 0, 1},
+      {TensorType_UINT8, {}, -1.0, 1.0}, tflite::MirrorPadMode_SYMMETRIC);
+  model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+  model.ApplyDelegateAndInvoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({1, 2, 3, 3, 4, 5, 6, 6, 4, 5, 6, 6}));
+}
+
+TEST(MirrorPadTest, Pad_1D_Reflect_Int8) {
+  MirrorPadOpModel<int8_t> model({TensorType_INT8, {3}, -1.0, 1.0}, {1, 2},
+                                 {0, 2}, {TensorType_INT8, {}, -1.0, 1.0},
+                                 tflite::MirrorPadMode_REFLECT);
+  model.PopulateTensor<int8_t>(model.input_tensor_id(), {1, 2, 3});
+  model.ApplyDelegateAndInvoke();
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 2, 1}));
+}
+
+TEST(MirrorPadTest, Pad_1D_Symmetric_UInt8) {
+  MirrorPadOpModel<uint8_t> model({TensorType_UINT8, {3}, -1.0, 1.0}, {1, 2},
+                                  {0, 2}, {TensorType_UINT8, {}, -1.0, 1.0},
+                                  tflite::MirrorPadMode_SYMMETRIC);
+  model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3});
+  model.ApplyDelegateAndInvoke();
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 3, 2}));
+}
+
+TEST(MirrorPadTest, PadBothSides_Reflect_Whole_UInt8) {
+  MirrorPadOpModel<uint8_t> model(
+      {TensorType_UINT8, {2, 3}, -1.0, 1.0}, {2, 2}, {1, 1, 2, 2},
+      {TensorType_UINT8, {}, -1.0, 1.0}, tflite::MirrorPadMode_REFLECT);
+  model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+  model.ApplyDelegateAndInvoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1,
+                                6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}));
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/experimental/delegates/hexagon/utils.cc b/tensorflow/lite/experimental/delegates/hexagon/utils.cc
index df7d742..d9d1480 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/utils.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/utils.cc
@@ -80,6 +80,7 @@
     case kTfLiteBuiltinL2Normalization:
     case kTfLiteBuiltinLogistic:
     case kTfLiteBuiltinMaxPool2d:
+    case kTfLiteBuiltinMirrorPad:
     case kTfLiteBuiltinMul:
     case kTfLiteBuiltinPad:
     case kTfLiteBuiltinQuantize:
@@ -159,6 +160,17 @@
       // causes an unexpected shift in dequantized values.
       return false;
     }
+    case kTfLiteBuiltinMirrorPad: {
+      if (!InputsWithCorrectTypes(
+              node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
+          !IsConstantTensor(GetInput(context, node, 1)))
+        return false;
+      const TfLiteMirrorPaddingParams* params =
+          reinterpret_cast<const TfLiteMirrorPaddingParams*>(
+              node->builtin_data);
+      return params->mode == kTfLiteMirrorPaddingReflect ||
+             params->mode == kTfLiteMirrorPaddingSymmetric;
+    }
     case kTfLiteBuiltinPad: {
       // TODO(b/139277813): Currently we only support padding with the default
       // of 0. Add support for user-defined constant if required.
diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple
index faa3f12..8e7b32e 100644
--- a/tensorflow/lite/experimental/ios/BUILD.apple
+++ b/tensorflow/lite/experimental/ios/BUILD.apple
@@ -22,13 +22,6 @@
     """,
 )
 
-TFL_LIBRARY_HDRS = [
-    "//tensorflow/lite/delegates/gpu:metal_delegate.h",
-    "//tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h",
-    "//tensorflow/lite/c:c_api.h",
-    "//tensorflow/lite/c:common.h",
-]
-
 TFL_FRAMEWORK_HDRS = [
     "//tensorflow/lite/delegates/gpu:metal_delegate.h",
     ":coreml_delegate.h",
@@ -43,19 +36,6 @@
     bundle_name = "TensorFlowLiteC",
     minimum_os_version = TFL_MINIMUM_OS_VERSION,
     deps = [
-        ":TensorFlowLiteC",
-    ],
-)
-
-objc_library(
-    name = "TensorFlowLiteC",
-    hdrs = TFL_LIBRARY_HDRS,
-    module_name = "TensorFlowLiteC",
-    weak_sdk_frameworks = [
-        "Metal",
-        "CoreML",
-    ],
-    deps = [
         ":tensorflow_lite_c",
     ],
 )
@@ -78,20 +58,22 @@
     ],
 )
 
-# Using this intermediate target is a workaround for a bug in bazel build rules
-# involving mixed objc_library & cc_library deps mentioned in (b/74809458).
-# When these dependencies are declared directly under the "TensorFlowLiteC"
-# target above, the resulting static library incorrectly contains duplicate
-# symbols from some ObjC code in the transitive dependencies.
-#
-# When a new dependency should be added to the TensorFlowLiteC framework, the
-# dependency should be added under this target instead.
-# When a new header file needs to be exposed, the header should be added to the
-# TFL_LIBRARY_HDRS list above.
 cc_library(
     name = "tensorflow_lite_c",
-    hdrs = TFL_LIBRARY_HDRS,
-    tags = ["nobuilder"],
+    hdrs = [
+        "//tensorflow/lite/c:c_api.h",
+        "//tensorflow/lite/c:common.h",
+        "//tensorflow/lite/delegates/gpu:metal_delegate.h",
+        "//tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h",
+    ],
+    linkopts = [
+        "-Wl,-weak_framework,CoreML",
+        "-Wl,-weak_framework,Metal",
+    ],
+    tags = [
+        "nobuilder",
+        "swift_module=TensorFlowLiteC",
+    ],
     deps = [
         "//tensorflow/lite/c:c_api",
         "//tensorflow/lite/delegates/gpu:metal_delegate",
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
index b19ef2e..bced23e 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
@@ -231,6 +231,26 @@
     return container.getDataType();
   }
 
+  /**
+   * Gets the image width.
+   *
+   * @throws IllegalStateException if the TensorImage has never loaded data.
+   * @throws IllegalArgumentException if the container data is corrupted.
+   */
+  public int getWidth() {
+    return container.getWidth();
+  }
+
+  /**
+   * Gets the image height.
+   *
+   * @throws IllegalStateException if the TensorImage has never loaded data.
+   * @throws IllegalArgumentException if the container data is corrupted.
+   */
+  public int getHeight() {
+    return container.getHeight();
+  }
+
   // Requires tensor shape [h, w, 3] or [1, h, w, 3].
   static void checkImageTensorShape(int[] shape) {
     SupportPreconditions.checkArgument(
@@ -273,6 +293,41 @@
       isBufferUpdated = true;
     }
 
+    int getWidth() {
+      SupportPreconditions.checkState(
+          isBitmapUpdated || isBufferUpdated,
+          "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
+      if (isBitmapUpdated) {
+        return bitmapImage.getWidth();
+      }
+      return getBufferDimensionSize(-2);
+    }
+
+    int getHeight() {
+      SupportPreconditions.checkState(
+          isBitmapUpdated || isBufferUpdated,
+          "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
+      if (isBitmapUpdated) {
+        return bitmapImage.getHeight();
+      }
+      return getBufferDimensionSize(-3);
+    }
+
+    // Internal helper method to get the size of one dimension in the shape of the `bufferImage`.
+    // Requires `isBufferUpdated` to be true.
+    // Throws `IllegalArgumentException` if data is corrupted.
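+    // Negative values of `dim` index from the end: e.g. for shape [1, h, w, 3],
+    // dim == -2 normalizes to index 2, i.e. the width.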
+    private int getBufferDimensionSize(int dim) {
+      int[] shape = bufferImage.getShape();
+      // The defensive check is needed because bufferImage might have been
+      // invalidly changed by the user (i.e. its internal data is corrupted).
+      TensorImage.checkImageTensorShape(shape);
+      dim = dim % shape.length;
+      if (dim < 0) {
+        dim += shape.length;
+      }
+      return shape[dim];
+    }
+
     public DataType getDataType() {
       return dataType;
     }
@@ -284,7 +339,8 @@
         return bitmapImage;
       }
       if (!isBufferUpdated) {
-        throw new IllegalStateException("Both buffer and bitmap data are obsolete.");
+        throw new IllegalStateException(
+            "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
       }
       if (bufferImage.getDataType() != DataType.UINT8) {
         throw new IllegalStateException(
@@ -310,7 +366,8 @@
         return bufferImage;
       }
       SupportPreconditions.checkArgument(
-          isBitmapUpdated, "Both buffer and bitmap data are obsolete.");
+          isBitmapUpdated,
+          "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
       int requiredFlatSize = bitmapImage.getWidth() * bitmapImage.getHeight() * 3;
       if (bufferImage == null
           || (!bufferImage.isDynamic() && bufferImage.getFlatSize() != requiredFlatSize)) {
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
index 16622a2..fa05be3 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
@@ -379,13 +379,13 @@
 
     // Check if the new shape is the same as current shape.
     int newFlatSize = computeFlatSize(shape);
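+    // Record the new shape even when the flat size is unchanged: the
+    // dimensions may still differ (e.g. resizing {2, 3} to {3, 2}).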
+    this.shape = shape.clone();
     if (flatSize == newFlatSize) {
       return;
     }
 
     // Update to the new shape.
     flatSize = newFlatSize;
-    this.shape = shape.clone();
     buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
     buffer.order(ByteOrder.nativeOrder());
   }
diff --git a/tensorflow/lite/experimental/support/metadata/java/BUILD b/tensorflow/lite/experimental/support/metadata/java/BUILD
index f1cd617..82b6e98 100644
--- a/tensorflow/lite/experimental/support/metadata/java/BUILD
+++ b/tensorflow/lite/experimental/support/metadata/java/BUILD
@@ -25,6 +25,10 @@
     name = "tensorflow-lite-support-metadata-lib",
     srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
     javacopts = JAVACOPTS,
+    resource_jars = [
+        "//tensorflow/lite/experimental/support/metadata:libmetadata_schema_java.jar",
+        "//tensorflow/lite/experimental/support/metadata:libschema_fbs_java.jar",
+    ],
     deps = [
         "//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
         "//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
diff --git a/tensorflow/lite/experimental/swift/BUILD.apple b/tensorflow/lite/experimental/swift/BUILD.apple
index 2ce8428..50130fc 100644
--- a/tensorflow/lite/experimental/swift/BUILD.apple
+++ b/tensorflow/lite/experimental/swift/BUILD.apple
@@ -13,11 +13,15 @@
 swift_library(
     name = "TensorFlowLite",
     srcs = glob(["Sources/*.swift"]),
+    linkopts = [
+        "-Wl,-weak_framework,CoreML",
+        "-Wl,-weak_framework,Metal",
+    ],
     module_name = "TensorFlowLite",
     tags = TFL_DEFAULT_TAGS,
     visibility = ios_visibility_whitelist(),
     deps = [
-        "//tensorflow/lite/experimental/ios:TensorFlowLiteC",
+        "//tensorflow/lite/experimental/ios:tensorflow_lite_c",
     ],
 )
 
diff --git a/tensorflow/lite/g3doc/convert/1x_compatibility.md b/tensorflow/lite/g3doc/convert/1x_compatibility.md
index adb2af4..9f9f277 100644
--- a/tensorflow/lite/g3doc/convert/1x_compatibility.md
+++ b/tensorflow/lite/g3doc/convert/1x_compatibility.md
@@ -1,30 +1,32 @@
-# TensorFlow 1.x compatibility
+# TensorFlow 1.x Compatibility <a name="differences"></a>
 
-The `tf.lite.TFLiteConverter` was updated between TensorFlow 1.X and 2.0. This
-document explains the differences between the 1.X and 2.0 versions of the
-converter, and provides information about how to use the 1.X version if
-required.
+The `tf.lite.TFLiteConverter` Python API was updated between TensorFlow 1.x and
+2.x. This document explains the differences between the two versions, and
+provides information about how to use the 1.x version if required.
 
-## Summary of changes in Python API between 1.X and 2.0 <a name="differences"></a>
-
-The following section summarizes the changes in the Python API from 1.X to 2.0.
 If any of the changes raise concerns, please file a
-[GitHub issue](https://github.com/tensorflow/tensorflow/issues).
+[GitHub Issue](https://github.com/tensorflow/tensorflow/issues).
 
-### Formats supported by `TFLiteConverter`
+Note: We highly recommend that you
+[migrate your TensorFlow 1.x code to TensorFlow 2.x code](https://www.tensorflow.org/guide/migrate)
+.
 
-The 2.0 version of the converter supports SavedModel and Keras model files
-generated in both 1.X and 2.0. However, the conversion process no longer
-supports "frozen graph" `GraphDef` files generated in 1.X.
+## Model formats
 
-#### Converting frozen graphs
+#### SavedModel and Keras
 
-Users who want to convert frozen graph `GraphDef` files (`.pb` files) to
-TensorFlow Lite should use `tf.compat.v1.lite.TFLiteConverter`.
+The `tf.lite.TFLiteConverter` API supports SavedModel and Keras HDF5 files
+generated in both TensorFlow 1.x and 2.x.
 
-The following snippet shows a frozen graph file being converted:
+#### Frozen Graph
+
+Note: TensorFlow 2.x no longer supports the generation of frozen graph models.
+
+The `tf.compat.v1.lite.TFLiteConverter` API supports frozen graph models
+generated in TensorFlow 1.x, as shown below:
 
 ```python
+import tensorflow as tf
 # Path to the frozen graph file
 graph_def_file = 'frozen_graph.pb'
 # A list of the names of the model's input tensors
@@ -32,70 +34,68 @@
 # A list of the names of the model's output tensors
 output_arrays = ['output_name']
 # Load and convert the frozen graph
-converter = lite.TFLiteConverter.from_frozen_graph(
+converter = tf.lite.TFLiteConverter.from_frozen_graph(
   graph_def_file, input_arrays, output_arrays)
 tflite_model = converter.convert()
 # Write the converted model to disk
 open("converted_model.tflite", "wb").write(tflite_model)
 ```
 
-### Quantization-aware training
+## Converter attributes
 
-The following attributes and methods associated with
-[quantization-aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize)
-have been removed from `TFLiteConverter` in TensorFlow 2.0:
+#### Renamed attributes
 
-*   `inference_type`
-*   `inference_input_type`
-*   `quantized_input_stats`
-*   `default_ranges_stats`
-*   `reorder_across_fake_quant`
-*   `change_concat_input_ranges`
-*   `post_training_quantize` - Deprecated in the 1.X API
-*   `get_input_arrays()`
+The following 1.x attribute has been renamed in 2.x.
 
-The rewriter function that supports quantization-aware training does not support
-models generated by TensorFlow 2.0. Additionally, TensorFlow Lite’s quantization
-API is being reworked and streamlined in a direction that supports
-quantization-aware training through the Keras API. These attributes will be
-removed in the 2.0 API until the new quantization API is launched. Users who
-want to convert models generated by the rewriter function can use
-`tf.compat.v1.lite.TFLiteConverter`.
+*   `target_ops` has been renamed to `target_spec.supported_ops` - In 2.x, in
+    line with future additions to the optimization framework, it has become the
+    `supported_ops` attribute of `TargetSpec`, as sketched below.
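+
+    As an illustrative sketch (assuming an existing `converter` instance), the
+    rename looks like this:
+
+    ```python
+    # TensorFlow 1.x attribute
+    converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+    # TensorFlow 2.x equivalent
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+    ```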
 
-### Changes to `TFLiteConverter` attributes
+#### Unsupported attributes
 
-The `target_ops` attribute has become an attribute of `TargetSpec` and renamed
-to `supported_ops` in line with future additions to the optimization framework.
+The following 1.x attributes have been removed in 2.x.
 
-Additionally, the following attributes have been removed:
-
-*   `drop_control_dependency` (default: `True`)
-*   _Graph visualization_ - The recommended approach for visualizing a
-    TensorFlow Lite graph in TensorFlow 2.0 will be to use
-    [visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py).
-    Unlike GraphViz, it enables users to visualize the graph after post training
-    quantization has occurred. The following attributes related to graph
-    visualization will be removed:
+*   _Quantization_ - In 2.x,
+    [quantization aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training)
+    is supported through the Keras API and
+    [post training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)
+    uses fewer, streamlined converter flags. Thus, the following attributes and
+    methods related to quantization have been removed:
+    *   `inference_type`
+    *   `quantized_input_stats`
+    *   `post_training_quantize`
+    *   `default_ranges_stats`
+    *   `reorder_across_fake_quant`
+    *   `change_concat_input_ranges`
+    *   `get_input_arrays()`
+*   _Visualization_ - In 2.x, the recommended approach for visualizing a
+    TensorFlow Lite graph is to use
+    [visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py)
+    . Unlike GraphViz, it enables users to visualize the graph after post
+    training quantization has occurred. Thus, the following attributes related
+    to graph visualization have been removed:
     *   `output_format`
     *   `dump_graphviz_dir`
     *   `dump_graphviz_video`
+*   _Frozen graph_ - In 2.x, the frozen graph model format has been removed.
+    Thus, the following attribute related to frozen graphs has been removed:
+    *   `drop_control_dependency`
 
-### General API changes
+## Unsupported APIs
 
-The following section explains several significant API changes between
-TensorFlow 1.X and 2.0.
+The following section explains several significant features in 1.x that have
+been removed in 2.x.
 
-#### Conversion methods
+#### Conversion APIs
 
-The following methods that were previously deprecated in 1.X will no longer be
-exported in 2.0:
+The following methods were deprecated in 1.x and have been removed in 2.x:
 
 *   `lite.toco_convert`
 *   `lite.TocoConverter`
 
-#### `lite.constants`
+#### `lite.constants` API
 
-The `lite.constants` API was removed in 2.0 in order to decrease duplication
+The `lite.constants` API was removed in 2.x in order to decrease duplication
 between TensorFlow and TensorFlow Lite. The following list maps the
 `lite.constant` type to the TensorFlow type:
 
@@ -106,12 +106,15 @@
 *   `lite.constants.STRING`: `tf.string`
 *   `lite.constants.QUANTIZED_UINT8`: `tf.uint8`
 
-Additionally, `lite.constants.TFLITE` and `lite.constants.GRAPHVIZ_DOT` were
-removed due to the deprecation of the `output_format` flag in `TFLiteConverter`.
+Additionally, the deprecation of the `output_format` flag in `TFLiteConverter`
+led to the removal of the following constants:
 
-#### `lite.OpHint`
+*   `lite.constants.TFLITE`
+*   `lite.constants.GRAPHVIZ_DOT`
 
-The `OpHint` API is currently not available in 2.0 due to an incompatibility
-with the 2.0 APIs. This API enables conversion of LSTM based models. Support for
-LSTMs in 2.0 is being investigated. All related `lite.experimental` APIs have
-been removed due to this issue.
+#### `lite.OpHint` API
+
+The `OpHint` API is currently unsupported due to an incompatibility with the 2.x
+APIs. This API enables conversion of LSTM based models. Support for LSTMs in 2.x
+is being investigated. All related `lite.experimental` APIs have been removed
+due to this issue.
diff --git a/tensorflow/lite/g3doc/performance/hexagon_delegate.md b/tensorflow/lite/g3doc/performance/hexagon_delegate.md
index 51af598..60fe946 100644
--- a/tensorflow/lite/g3doc/performance/hexagon_delegate.md
+++ b/tensorflow/lite/g3doc/performance/hexagon_delegate.md
@@ -259,43 +259,7 @@
     *   This is tentatively planned for a future release, though there is no
         concrete timeline.
 *   Which ops are supported by the delegate?
-    *   Initial list of supported ops:
-        *   Add
-        *   ArgMax
-        *   ArgMin
-        *   AveragePool2D (without any activation)
-        *   Concat
-        *   Conv2D with following constraints:
-            *   stride width/height <= 3
-        *   DepthToSpace
-        *   DepthwiseConv2D with following constraints:
-            *   Filter width == 3
-            *   depth_multiplier == 1
-            *   dilation only supported when stride == 1
-            *   Otherwise, stride height/width <= 3
-        *   FullyConnected (without any activation)
-        *   Hardswish
-        *   L2Normalization (without any activation)
-        *   Logistic (aka Sigmoid)
-        *   MaxPool2D (without any activation)
-        *   Mul (without any activation)
-        *   Neg
-        *   Pad: Only supports 0 padding
-        *   Relu
-        *   Relu6
-        *   Reshape
-        *   Resize Bilinear with following constraints:
-            *   Requested size <= 65
-        *   Resize Nearest Neighbor
-        *   SoftMax
-        *   SpaceToDepth
-        *   Split
-        *   Sub
-        *   Tanh
-        *   Transpose
-        *   TransposeConv2D with following constraints:
-            *   stride height/width <= 3
-            *   dilation height/width == 1
+    *   See the current list of [supported ops and constraints](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/delegates/hexagon/README.md)
 *   How can I tell that the model is using the DSP when I enable the delegate?
     *   Two log messages will be printed when you enable the delegate - one to
         indicate if the delegate was created and another to indicate how many
diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md
index 194d102..a526be7 100644
--- a/tensorflow/lite/g3doc/performance/post_training_quantization.md
+++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md
@@ -4,51 +4,44 @@
 while also improving CPU and hardware accelerator latency, with little
 degradation in model accuracy. You can perform these techniques using an
 already-trained float TensorFlow model when you convert it to TensorFlow Lite
-format.
+format using the [TensorFlow Lite Converter](../convert/).
 
 Note: The procedures on this page require TensorFlow 1.15 or higher.
 
-
-### Optimization options
+### Optimization methods
 
 There are several post-training quantization options to choose from. Here is a
 summary table of the choices and the benefits they provide:
 
-| Technique                 | Benefits                  | Hardware            |
-| ------------------------- | ------------------------- | ------------------- |
-| Dynamic range             | 4x smaller, 2-3x speedup, | CPU                 |
-: quantization              : accuracy                  :                     :
-| Full integer quantization | 4x smaller, 3x+ speedup   | CPU, Edge TPU, etc. |
-| Float16 quantization      | 2x smaller, potential GPU | CPU/GPU             |
-:                           : acceleration              :                     :
+| Technique            | Benefits                  | Hardware         |
+| -------------------- | ------------------------- | ---------------- |
+| Dynamic range        | 4x smaller, 2-3x speedup  | CPU              |
+: quantization         :                           :                  :
+| Full integer         | 4x smaller, 3x+ speedup   | CPU, Edge TPU,   |
+: quantization         :                           : Microcontrollers :
+| Float16 quantization | 2x smaller, potential GPU | CPU, GPU         |
+:                      : acceleration              :                  :
 
 This decision tree can help determine which post-training quantization method is
 best for your use case:
 
 ![post-training optimization options](images/optimization.jpg)
 
-Alternatively, you might achieve higher accuracy if you perform
-[quantization-aware training](
-https://github.com/tensorflow/tensorflow/tree/r1.14/tensorflow/contrib/quantize).
-However, doing so requires some model modifications to add fake quantization
-nodes, whereas the post-training quantization techniques on this page use an
-existing pre-trained model.
-
 ### Dynamic range quantization
 
 The simplest form of post-training quantization statically quantizes only the
-weights from floating point to 8-bits of precision. This technique is enabled as
-an option in the [TensorFlow Lite converter](../convert/):
+weights from floating point to integer, which has 8 bits of precision:
 
-```
+<pre>
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
+<b>converter.optimizations = [tf.lite.Optimize.DEFAULT]</b>
 tflite_quant_model = converter.convert()
-```
+</pre>
 
-At inference, weights are converted from 8-bits of precision to floating point and
-computed using floating-point kernels. This conversion is done once and cached to reduce latency.
+At inference, weights are converted from 8 bits of precision to floating point
+and computed using floating-point kernels. This conversion is done once and
+cached to reduce latency.
 
 To further improve latency, "dynamic-range" operators dynamically quantize
 activations based on their range to 8-bits and perform computations with 8-bit
@@ -58,89 +51,105 @@
 fixed-point computation. Dynamic-range ops are available for the most
 compute-intensive operators in a network:
 
-*  [tf.contrib.layers.fully_connected](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/fully_connected)
-*  [tf.nn.conv2d](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d)
-*  [tf.nn.embedding_lookup](https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup)
-*  [BasicRNN](https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicRNNCell)
-*  [tf.nn.bidirectional_dynamic_rnn for BasicRNNCell type](https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn)
-*  [tf.nn.dynamic_rnn for LSTM and BasicRNN Cell types](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
+*   `tf.keras.layers.Dense`
+*   `tf.keras.layers.Conv2D`
+*   `tf.keras.layers.LSTM`
+*   `tf.nn.embedding_lookup`
+*   `tf.compat.v1.nn.rnn_cell.BasicRNNCell`
+*   `tf.compat.v1.nn.bidirectional_dynamic_rnn`
+*   `tf.compat.v1.nn.dynamic_rnn`
 
-
-### Full integer quantization of weights and activations
+### Full integer quantization
 
 You can get further latency improvements, reductions in peak memory usage, and
-access to integer only hardware accelerators by making sure all model math is
-quantized.
+access to integer only hardware devices or accelerators by making sure all model
+math is integer quantized.
 
 To do this, you need to measure the dynamic range of activations and inputs by
-supplying a representative data set. You can simply create an input data
-generator and provide it to our converter. For example:
+supplying sample input data to the converter. Refer to the
+`representative_dataset_gen()` function used in the following code.
 
-```
+#### Integer with float fallback (using default float input/output)
+
+To fully integer quantize a model while falling back to float operators when an
+integer implementation is unavailable (which ensures conversion occurs
+smoothly), use the following steps:
+
+<pre>
 import tensorflow as tf
-
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+<b>converter.optimizations = [tf.lite.Optimize.DEFAULT]
 def representative_dataset_gen():
   for _ in range(num_calibration_steps):
     # Get sample input data as a numpy array in a method of your choosing.
     yield [input]
-
-converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset_gen</b>
 tflite_quant_model = converter.convert()
-```
+</pre>
 
-The resulting model should be fully quantized, but any
-ops that do not have quantized implementations are left in
-floating point. This allows conversion to occur smoothly, but the model won't be
-compatible with accelerators that require full integer quantization.
+Note: This won't be compatible with integer only devices (such as 8-bit
+microcontrollers) and accelerators (such as the Coral Edge TPU). For convenience
+during inference, the input and output still remain float in order to have the
+same interface as the original float only model.
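+
+As a quick sanity check (a sketch, assuming the `tflite_quant_model` produced
+above), you can confirm that the input and output tensors remain float:
+
+<pre>
+interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
+interpreter.allocate_tensors()
+print(interpreter.get_input_details()[0]['dtype'])   # numpy.float32
+print(interpreter.get_output_details()[0]['dtype'])  # numpy.float32
+</pre>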
 
-Additionally, the model still uses float input and output for convenience.
+#### Integer only
 
-To ensure compatibility with some accelerators (such as the Coral Edge TPU), you
-can enforce full integer quantization for all ops and use integer input and
-output by adding the following lines before you convert:
+*This is a common use case for
+[TensorFlow Lite for Microcontrollers](https://www.tensorflow.org/lite/microcontrollers)
+and [Coral Edge TPUs](https://coral.ai/).*
 
-```
-converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-converter.inference_input_type = tf.uint8
-converter.inference_output_type = tf.uint8
-```
+Additionally, to ensure compatibility with integer only devices (such as 8-bit
+microcontrollers) and accelerators (such as the Coral Edge TPU), you can enforce
+full integer quantization for all ops including the input and output, by using
+the following steps:
 
-The first line makes the converter throw an error if it encounters an operation
-it cannot currently quantize.
-
-Note: `target_spec.supported_ops` was previously `target_ops` in the Python API.
-
-
-### Float16 quantization of weights
-
-You can reduce the size of a floating point model by quantizing the weights to
-float16, the IEEE standard for 16-bit floating point numbers. The advantages of
-this quantization are as follows:
-
--   reduce model size by up to half (since all weights are now half the original
-    size)
--   minimal loss in accuracy
--   some delegates (e.g. the GPU delegate) can operate directly on float16 data,
-    which results in faster execution than float32 computations.
-
-This quantization may not be a good choice if you need maximum performance (a
-quantization to fixed point math would be better in that case). To enable
-float16 quantization of weights, specify "DEFAULT" optimization as above and
-then specify that float16 is in supported types for the target_spec:
-
-```
+<pre>
 import tensorflow as tf
 converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]
+def representative_dataset_gen():
+  for _ in range(num_calibration_steps):
+    # Get sample input data as a numpy array in a method of your choosing.
+    yield [input]
+converter.representative_dataset = representative_dataset_gen
+<b>converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]</b>
+<b>converter.inference_input_type = tf.int8</b>  # or tf.uint8
+<b>converter.inference_output_type = tf.int8</b>  # or tf.uint8
 tflite_quant_model = converter.convert()
-```
+</pre>
 
-By default, a float16 quantized model will "dequantize" the weights values to
-float32 when run on the CPU. The GPU delegate will not perform this
-dequantization, since it can operate on float16 data.
+Note: The converter will throw an error if it encounters an operation it cannot
+currently quantize.
+
+### Float16 quantization
+
+You can reduce the size of a floating point model by quantizing the weights to
+float16, the IEEE standard for 16-bit floating point numbers. To enable float16
+quantization of weights, use the following steps:
+
+<pre>
+import tensorflow as tf
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+<b>converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]</b>
+tflite_quant_model = converter.convert()
+</pre>
+
+The advantages of this quantization are as follows:
+
+*   Reduce model size by up to half (since all weights are now half the original
+    size).
+*   Minimal loss in accuracy.
+*   Some delegates (e.g. the GPU delegate) can operate directly on float16
+    data, which results in faster execution than float32 computations.
+
+The disadvantages of this quantization are as follows:
+
+*   Not a good choice for maximum performance (a quantization to fixed point
+    math would be better in that case).
+*   By default, a float16 quantized model will "dequantize" the weight values
+    to float32 when run on the CPU. (Note that the GPU delegate will not
+    perform this dequantization, since it can operate on float16 data.)
 
 ### Model accuracy
 
@@ -152,13 +161,18 @@
 within acceptable limits. There is a tool to evaluate
 [TensorFlow Lite model accuracy](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/accuracy/ilsvrc/README.md){:.external}.
 
-If the accuracy drop is too high, consider using
-[quantization aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize){:.external}.
+Alternatively, if the accuracy drop is too high, consider using
+[quantization aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training)
+. However, doing so requires modifications during model training to add fake
+quantization nodes, whereas the post-training quantization techniques on this
+page use an existing pre-trained model.
 
 ### Representation for quantized tensors
 
 8-bit quantization approximates floating point values using the following
-formula. `real_value = (int8_value - zero_point) * scale`.
+formula.
+
+$$real\_value = (int8\_value - zero\_point) \times scale$$
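+
+For example, with a (hypothetical) `scale` of 0.5 and `zero_point` of -1, the
+int8 value 5 represents the real value (5 - (-1)) * 0.5 = 3.0.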
 
 The representation has two main parts:
 
diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb
index ee46795..464a5d1 100644
--- a/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb
+++ b/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb
@@ -49,7 +49,7 @@
       "metadata": {
         "colab_type": "text",
         "id": "nDABAblytltI"
-      },
+      }, 
       "source": [
         "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
         "  \u003ctd\u003e\n",
diff --git a/tensorflow/lite/java/ovic/BUILD b/tensorflow/lite/java/ovic/BUILD
index 947fbee..e64bd30 100644
--- a/tensorflow/lite/java/ovic/BUILD
+++ b/tensorflow/lite/java/ovic/BUILD
@@ -58,7 +58,6 @@
     deps = [
         "//tensorflow/lite/java:tensorflowlite",
         "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
-        "@org_checkerframework_qual",
     ],
 )
 
@@ -75,7 +74,6 @@
         "//tensorflow/lite/java:tensorflowlite_java",
         "//tensorflow/lite/java/src/main/native",
         "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
-        "@org_checkerframework_qual",
     ],
 )
 
@@ -114,7 +112,6 @@
     deps = [
         "//tensorflow/lite/java:tensorflowlite",
         "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
-        "@org_checkerframework_qual",
     ],
 )
 
@@ -131,6 +128,5 @@
         "//tensorflow/lite/java:tensorflowlite_java",
         "//tensorflow/lite/java/src/main/native",
         "//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
-        "@org_checkerframework_qual",
     ],
 )
diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h
index ad068dd..5793b08 100644
--- a/tensorflow/lite/kernels/kernel_util.h
+++ b/tensorflow/lite/kernels/kernel_util.h
@@ -87,6 +87,10 @@
 }
 
 // Determines whether tensor is constant.
+// TODO(b/138199592): Introduce new query which checks for constant OR
+// persistent-read-only, which would be useful for most tensor kernels that
+// are potentially dynamic based on the input tensor value availability at the
+// time of prepare.
 inline bool IsConstantTensor(const TfLiteTensor* tensor) {
   return tensor->allocation_type == kTfLiteMmapRo;
 }
@@ -105,6 +109,14 @@
   }
 }
 
+// Sets tensor to persistent and read-only.
+inline void SetTensorToPersistentRo(TfLiteTensor* tensor) {
+  if (tensor->allocation_type != kTfLitePersistentRo) {
+    tensor->allocation_type = kTfLitePersistentRo;
+    tensor->data.raw = nullptr;
+  }
+}
+
 // Determines whether it is a hybrid op - one that has float inputs and
 // quantized weights.
 inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) {
diff --git a/tensorflow/lite/kernels/rank.cc b/tensorflow/lite/kernels/rank.cc
index 8e27ebc..53fd92f 100644
--- a/tensorflow/lite/kernels/rank.cc
+++ b/tensorflow/lite/kernels/rank.cc
@@ -30,19 +30,23 @@
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   output->type = kTfLiteInt32;
 
+  // By design, the input shape is always known at the time of Prepare, even
+  // if the preceding op that generates |input| is dynamic. Thus, we can
+  // always compute the rank immediately, without waiting for Eval.
+  SetTensorToPersistentRo(output);
+
   // Rank produces a 0-D int32 Tensor representing the rank of input.
   TfLiteIntArray* output_size = TfLiteIntArrayCreate(0);
-  return context->ResizeTensor(context, output, output_size);
-}
+  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size));
 
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   TF_LITE_ENSURE_EQ(context, NumDimensions(output), 0);
 
+  // Immediately propagate the known rank to the output tensor. This allows
+  // downstream ops that rely on the value to use it during prepare.
   if (output->type == kTfLiteInt32) {
     int32_t* output_data = GetTensorData<int32_t>(output);
     *output_data = NumDimensions(input);
@@ -53,6 +57,10 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
 }  // namespace rank
 
 TfLiteRegistration* Register_RANK() {
diff --git a/tensorflow/lite/kernels/rank_test.cc b/tensorflow/lite/kernels/rank_test.cc
index f3dc971..5373a0a 100644
--- a/tensorflow/lite/kernels/rank_test.cc
+++ b/tensorflow/lite/kernels/rank_test.cc
@@ -43,6 +43,9 @@
 
   std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+  TfLiteAllocationType GetOutputAllocationType() const {
+    return interpreter_->tensor(interpreter_->outputs()[0])->allocation_type;
+  }
 
  private:
   int input_;
@@ -51,6 +54,13 @@
 
 TEST(RankOpTest, InputTypeFloat) {
   RankOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32);
+  ASSERT_EQ(model.GetOutputAllocationType(), kTfLitePersistentRo);
+
+  // Unlike most ops, Rank populates outputs in Prepare().
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
+  EXPECT_TRUE(model.GetOutputShape().empty());
+
+  // Invoke is superfluous and shouldn't change the output.
   model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
@@ -59,7 +69,6 @@
 
 TEST(RankOpTest, InputTypeInt) {
   RankOpModel model({1, 3, 1, 3, 5}, TensorType_INT32);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
   EXPECT_TRUE(model.GetOutputShape().empty());
@@ -67,7 +76,6 @@
 
 TEST(RankOpTest, ScalarTensor) {
   RankOpModel model({}, TensorType_FLOAT32);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
   EXPECT_TRUE(model.GetOutputShape().empty());
@@ -75,7 +83,6 @@
 
 TEST(RankOpTest, EmptyTensor) {
   RankOpModel model({1, 0}, TensorType_FLOAT32);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
   EXPECT_TRUE(model.GetOutputShape().empty());
diff --git a/tensorflow/lite/kernels/shape.cc b/tensorflow/lite/kernels/shape.cc
index 88794fe..d979f08 100644
--- a/tensorflow/lite/kernels/shape.cc
+++ b/tensorflow/lite/kernels/shape.cc
@@ -54,19 +54,22 @@
       return kTfLiteError;
   }
 
+  // By design, the input shape is always known at the time of Prepare, even
+  // if the preceding op that generates |input| is dynamic. Thus, we can
+  // always compute the shape immediately, without waiting for Eval.
+  SetTensorToPersistentRo(output);
+
   // Shape always produces a 1-dimensional output tensor, where each output
   // element is the length of the corresponding input tensor's dimension.
   TfLiteIntArray* output_size = TfLiteIntArrayCreate(1);
   output_size->data[0] = NumDimensions(input);
-  return context->ResizeTensor(context, output, output_size);
-}
+  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size));
 
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   TFLITE_DCHECK_EQ(NumDimensions(output), 1);
   TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input));
 
+  // Immediately propagate the known shape to the output tensor. This allows
+  // downstream ops that rely on the value to use it during prepare.
   switch (output->type) {
     case kTfLiteInt32:
       ExtractShape(input, GetTensorData<int32_t>(output));
@@ -81,6 +84,10 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
 }  // namespace shape
 
 TfLiteRegistration* Register_SHAPE() {
diff --git a/tensorflow/lite/kernels/shape_test.cc b/tensorflow/lite/kernels/shape_test.cc
index 6a7dad4..3eeb83f 100644
--- a/tensorflow/lite/kernels/shape_test.cc
+++ b/tensorflow/lite/kernels/shape_test.cc
@@ -45,6 +45,9 @@
   int32_t GetOutputSize() { return GetTensorSize(output_); }
   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+  TfLiteAllocationType GetOutputAllocationType() const {
+    return interpreter_->tensor(interpreter_->outputs()[0])->allocation_type;
+  }
 
  private:
   int input_;
@@ -54,6 +57,13 @@
 TEST(ShapeOpTest, OutTypeInt) {
   ShapeOpModel<int32_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
                               TensorType_INT32);
+  ASSERT_EQ(model.GetOutputAllocationType(), kTfLitePersistentRo);
+
+  // Unlike most ops, Shape populates outputs in Prepare().
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
+
+  // Invoke is superfluous and shouldn't change the output.
   model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
@@ -63,7 +73,6 @@
 TEST(ShapeOpTest, OutTypeInt64) {
   ShapeOpModel<int64_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
                               TensorType_INT64);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
@@ -71,7 +80,6 @@
 
 TEST(ShapeOpTest, ScalarTensor) {
   ShapeOpModel<int32_t> model({}, TensorType_FLOAT32, TensorType_INT32);
-  model.Invoke();
 
   EXPECT_EQ(model.GetOutputSize(), 0);
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0}));
@@ -79,7 +87,6 @@
 
 TEST(ShapeOpTest, EmptyTensor) {
   ShapeOpModel<int32_t> model({1, 0}, TensorType_FLOAT32, TensorType_INT32);
-  model.Invoke();
 
   EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
diff --git a/tensorflow/lite/micro/apollo3evb/micro_time.cc b/tensorflow/lite/micro/apollo3evb/micro_time.cc
new file mode 100644
index 0000000..12c9ae5
--- /dev/null
+++ b/tensorflow/lite/micro/apollo3evb/micro_time.cc
@@ -0,0 +1,72 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Apollo3-specific implementation of timer functions. Platforms are not
+// required to implement these timer methods, but they are required to enable
+// profiling.
+
+// On platforms that have a POSIX stack or C library, it can be written using
+// methods from <sys/time.h> or clock() from <time.h>.
+
+// To add an equivalent function for your own platform, create your own
+// implementation file, and place it in a subfolder named after the OS
+// you're targeting. For example, see the Cortex M bare metal version in
+// tensorflow/lite/micro/bluepill/micro_timer.cc or the mbed one in
+// tensorflow/lite/micro/mbed/micro_timer.cc.
+
+#include "tensorflow/lite/micro/micro_time.h"
+
+// These are headers from Ambiq's Apollo3 SDK.
+#include "am_bsp.h"         // NOLINT
+#include "am_mcu_apollo.h"  // NOLINT
+#include "am_util.h"        // NOLINT
+
+namespace tflite {
+namespace {
+
+// Select CTIMER 1 as the benchmarking timer on Sparkfun Edge. This timer must
+// not be used elsewhere.
+constexpr int kTimerNum = 1;
+
+// Clock set to operate at 12MHz.
+constexpr int kClocksPerSecond = 12e6;
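+
+// Note: at 12MHz, a 32-bit tick counter wraps roughly every
+// 2^32 / 12e6 ~= 358 seconds, so longer measurements must handle wraparound.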
+
+}  // namespace
+
+int32_t ticks_per_second() { return kClocksPerSecond; }
+
+// Calling this method enables a timer that runs for eternity. The user is
+// responsible for avoiding trampling on this timer's config, otherwise timing
+// measurements may no longer be valid.
+int32_t GetCurrentTimeTicks() {
+  // TODO(b/150808076): Split out initialization, initialize in interpreter.
+  static bool is_initialized = false;
+  if (!is_initialized) {
+    am_hal_ctimer_config_t timer_config;
+    // Operate as a 32-bit timer.
+    timer_config.ui32Link = 1;
+    // Set timer A to continuous mode at 12MHz.
+    timer_config.ui32TimerAConfig =
+        AM_HAL_CTIMER_FN_CONTINUOUS | AM_HAL_CTIMER_HFRC_12MHZ;
+
+    am_hal_ctimer_stop(kTimerNum, AM_HAL_CTIMER_BOTH);
+    am_hal_ctimer_clear(kTimerNum, AM_HAL_CTIMER_BOTH);
+    am_hal_ctimer_config(kTimerNum, &timer_config);
+    am_hal_ctimer_start(kTimerNum, AM_HAL_CTIMER_TIMERA);
+    is_initialized = true;
+  }
+  return CTIMERn(kTimerNum)->TMR0;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h
index 2ed3e45..918192c 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h
+++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h
@@ -65,7 +65,11 @@
   ae_q56s result_56 = AE_MULP24S_HH(x_24x2, quantized_multiplier_24x2);
 
   // Shift right if shift amount is positive, left if shift amount is negative.
-  result_56 = AE_SLAASQ56S(result_56, shift_amount);
+  if (shift_amount >= 0) {
+    result_56 = AE_Q56S_SRA(result_56, shift_amount);
+  } else {
+    result_56 = AE_Q56S_SLA(result_56, -shift_amount);
+  }
 
   // Round off the bottom 16 bits.
   // Q48.0 / 2^16 -> Q32.0 aligned to 48 bits.
diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc
index c2b3216..54ce338 100644
--- a/tensorflow/lite/micro/micro_allocator.cc
+++ b/tensorflow/lite/micro/micro_allocator.cc
@@ -388,10 +388,8 @@
     return kTfLiteError;
   }
   subgraph_ = (*subgraphs)[0];
-  tensors_ = subgraph_->tensors();
-  operators_ = subgraph_->operators();
 
-  context_->tensors_size = tensors_->size();
+  context_->tensors_size = subgraph_->tensors()->size();
   context_->tensors =
       reinterpret_cast<TfLiteTensor*>(memory_allocator_->AllocateFromTail(
           sizeof(TfLiteTensor) * context_->tensors_size,
@@ -405,9 +403,9 @@
   }
 
   // Initialize runtime tensors in context_ using the flatbuffer.
-  for (size_t i = 0; i < tensors_->size(); ++i) {
+  for (size_t i = 0; i < subgraph_->tensors()->size(); ++i) {
     TfLiteStatus status = internal::InitializeRuntimeTensor(
-        memory_allocator_, *tensors_->Get(i), model_->buffers(),
+        memory_allocator_, *subgraph_->tensors()->Get(i), model_->buffers(),
         error_reporter_, &context_->tensors[i]);
     if (status != kTfLiteOk) {
       TF_LITE_REPORT_ERROR(error_reporter_, "Failed to initialize tensor %d",
@@ -472,7 +470,7 @@
 
   auto* output = reinterpret_cast<NodeAndRegistration*>(
       memory_allocator_->AllocateFromTail(
-          sizeof(NodeAndRegistration) * operators_->size(),
+          sizeof(NodeAndRegistration) * subgraph_->operators()->size(),
           alignof(NodeAndRegistration)));
   if (output == nullptr) {
     TF_LITE_REPORT_ERROR(
@@ -483,8 +481,8 @@
   TfLiteStatus status = kTfLiteOk;
   auto* opcodes = model_->operator_codes();
   MicroBuiltinDataAllocator builtin_data_allocator(memory_allocator_);
-  for (size_t i = 0; i < operators_->size(); ++i) {
-    const auto* op = operators_->Get(i);
+  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
+    const auto* op = subgraph_->operators()->Get(i);
     size_t index = op->opcode_index();
     if (index >= opcodes->size()) {
       TF_LITE_REPORT_ERROR(error_reporter_,
@@ -567,7 +565,7 @@
 
     AllocationInfoBuilder builder(error_reporter_, &tmp_allocator);
     TF_LITE_ENSURE_STATUS(
-        builder.Init(tensors_->size(), scratch_buffer_count_));
+        builder.Init(subgraph_->tensors()->size(), scratch_buffer_count_));
     TF_LITE_ENSURE_STATUS(builder.AddTensors(subgraph_, context_->tensors));
     TF_LITE_ENSURE_STATUS(builder.AddScratchBuffers(scratch_buffer_handles_));
     const AllocationInfo* allocation_info = builder.Finish();
@@ -606,8 +604,8 @@
 
   // Data in variables need to be kept for the next invocation so allocating
   // them from the tail (persistent area).
-  if (AllocateVariables(tensors_, context_->tensors, memory_allocator_) !=
-      kTfLiteOk) {
+  if (AllocateVariables(subgraph_->tensors(), context_->tensors,
+                        memory_allocator_) != kTfLiteOk) {
     TF_LITE_REPORT_ERROR(
         error_reporter_,
         "Failed to allocate variables. Please increase arena size.");
diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h
index b16f814..6a6e1e0 100644
--- a/tensorflow/lite/micro/micro_allocator.h
+++ b/tensorflow/lite/micro/micro_allocator.h
@@ -135,8 +135,6 @@
   size_t scratch_buffer_count_ = 0;
 
   const SubGraph* subgraph_;
-  const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators_;
-  const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors_;
 };
 
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/micro_optional_debug_tools.cc b/tensorflow/lite/micro/micro_optional_debug_tools.cc
index 70f16c7..42c42ae 100644
--- a/tensorflow/lite/micro/micro_optional_debug_tools.cc
+++ b/tensorflow/lite/micro/micro_optional_debug_tools.cc
@@ -95,6 +95,8 @@
       return "kTfLiteArenaRw";
     case kTfLiteArenaRwPersistent:
       return "kTfLiteArenaRwPersistent";
+    case kTfLitePersistentRo:
+      return "kTfLitePersistentRo";
   }
   return "(invalid)";
 }
diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc
index c5ccdb9..2e25b0a 100644
--- a/tensorflow/lite/optional_debug_tools.cc
+++ b/tensorflow/lite/optional_debug_tools.cc
@@ -77,6 +77,8 @@
       return "kTfLiteArenaRw";
     case kTfLiteArenaRwPersistent:
       return "kTfLiteArenaRwPersistent";
+    case kTfLitePersistentRo:
+      return "kTfLitePersistentRo";
   }
   return "(invalid)";
 }
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index 9ddd09e..1bcb2ce 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -269,9 +269,7 @@
                                                   [out_tensor])
     converter.inference_input_type = lite_constants.QUANTIZED_UINT8
     converter.inference_type = lite_constants.FLOAT
-    converter.quantized_input_stats = {
-        'Placeholder': (0., 1.)
-    }  # mean, std_dev
+    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
     tflite_model = converter.convert()
     self.assertTrue(tflite_model)
 
@@ -1327,6 +1325,41 @@
     tflite_model = converter.convert()
     self.assertTrue(tflite_model)
 
+  def testResizeWithShape(self):
+    with ops.Graph().as_default():
+      # Construct a graph with a dynamically shaped input and an internal node
+      # that relies on that input's shape.
+      in_tensor = array_ops.placeholder(
+          shape=[None, None], dtype=dtypes.float32)
+      in_tensor2 = [[1, 2], [3, 4]]
+      out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor))
+      sess = session.Session()
+
+    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+                                                  [out_tensor])
+    converter.experimental_new_converter = True
+    tflite_model = converter.convert()
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertTrue(([1, 1] == input_details[0]['shape']).all())
+    self.assertTrue(([-1, -1] == input_details[0]['shape_signature']).all())
+
+    # Resize tensor and invoke.
+    interpreter.resize_tensor_input(0, [4])
+    interpreter.allocate_tensors()
+    interpreter.invoke()
+
+    # The output should be reshaped properly according to the resized input.
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual(np.int32, output_details[0]['dtype'])
+    self.assertTrue(([4] == output_details[0]['shape']).all())
+    output_data = interpreter.get_tensor(output_details[0]['index'])
+    self.assertTrue(([1, 2, 3, 4] == output_data).all())
+
   def testResizingIntermediateDynamicTensor(self):
     # This is a regression test for the case where shape of dynamic output
     # tensors changes between invocations.
diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD
index 9d50f1a..df85f65 100644
--- a/tensorflow/lite/testing/BUILD
+++ b/tensorflow/lite/testing/BUILD
@@ -329,7 +329,7 @@
             "//tensorflow/core:android_tensorflow_lib",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib",
+            "//tensorflow/core:portable_tensorflow_lib",
         ],
     }),
 )
@@ -368,7 +368,7 @@
             "//tensorflow/core:android_tensorflow_lib",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib",
+            "//tensorflow/core:portable_tensorflow_lib",
         ],
     }),
 )
@@ -408,7 +408,7 @@
             "//tensorflow/core:android_tensorflow_lib",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib",
+            "//tensorflow/core:portable_tensorflow_lib",
         ],
     }),
 )
@@ -443,7 +443,7 @@
             "//tensorflow/core:android_tensorflow_lib",
         ],
         "//tensorflow:ios": [
-            "//tensorflow/core:ios_tensorflow_lib",
+            "//tensorflow/core:portable_tensorflow_lib",
         ],
     }),
 )
diff --git a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
index 9816cc1..171d522 100644
--- a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -271,8 +271,8 @@
   const double magnitude =
       std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min);
   const double tolerated = 1e-6 * magnitude;
-  return std::abs(minmax1.min - minmax2.min) < tolerated &&
-         std::abs(minmax1.max - minmax2.max) < tolerated;
+  return std::abs(minmax1.min - minmax2.min) <= tolerated &&
+         std::abs(minmax1.max - minmax2.max) <= tolerated;
 }
 
 // Propagates MinMax from any of the listed arrays, to all others.
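
For reference, a Python restatement of the comparison above (illustration only; ranges are `(min, max)` tuples). With the old strict `<`, two exactly equal ranges failed to match whenever `tolerated` was 0, e.g. for zero-width ranges, which `<=` fixes:

```python
def min_max_approximately_equal(minmax1, minmax2):
  magnitude = min(minmax1[1] - minmax1[0], minmax2[1] - minmax2[0])
  tolerated = 1e-6 * magnitude
  return (abs(minmax1[0] - minmax2[0]) <= tolerated and
          abs(minmax1[1] - minmax2[1]) <= tolerated)

# A zero-width range compared with itself: tolerated == 0, so strict `<`
# returned False; `<=` now returns True.
assert min_max_approximately_equal((0.5, 0.5), (0.5, 0.5))
```
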
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index 57b791a..917fd24 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1118,6 +1118,7 @@
         GetVersioningOpSig(builtin_op(), op_signature);
     op_sig.options.resize.half_pixel_centers =
         resize_bilinear_op.half_pixel_centers;
+    op_sig.options.resize.align_corners = resize_bilinear_op.align_corners;
     return ::tflite::GetBuiltinOperatorVersion(op_sig);
   }
 };
@@ -1147,6 +1148,7 @@
     ::tflite::OpSignature op_sig =
         GetVersioningOpSig(builtin_op(), op_signature);
     op_sig.options.resize.half_pixel_centers = resize_nn_op.half_pixel_centers;
+    op_sig.options.resize.align_corners = resize_nn_op.align_corners;
     return ::tflite::GetBuiltinOperatorVersion(op_sig);
   }
 };
diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD
index 3570722..f6cb717 100644
--- a/tensorflow/lite/tools/benchmark/BUILD
+++ b/tensorflow/lite/tools/benchmark/BUILD
@@ -142,6 +142,7 @@
         ":profiling_listener",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:string_util",
+        "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/profiling:platform_profiler",
         "//tensorflow/lite/profiling:profile_summary_formatter",
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 489780e..969713c 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -29,6 +29,7 @@
 #include "absl/base/attributes.h"
 #include "absl/strings/numbers.h"
 #include "ruy/profiler/profiler.h"  // from @ruy
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/op_resolver.h"
@@ -596,17 +597,20 @@
   return kTfLiteOk;
 }
 
-TfLiteStatus BenchmarkTfLiteModel::Init() {
-  TF_LITE_ENSURE_STATUS(LoadModel());
-
+TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() {
   auto resolver = GetOpResolver();
-
   const int32_t num_threads = params_.Get<int32_t>("num_threads");
   tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads);
   if (!interpreter_) {
-    TFLITE_LOG(ERROR) << "Failed to construct interpreter";
+    TFLITE_LOG(ERROR) << "Failed to initialize the interpreter";
     return kTfLiteError;
   }
+  return kTfLiteOk;
+}
+
+TfLiteStatus BenchmarkTfLiteModel::Init() {
+  TF_LITE_ENSURE_STATUS(LoadModel());
+  TF_LITE_ENSURE_STATUS(InitInterpreter());
 
   // Install profilers if necessary right after interpreter is created so that
   // any memory allocations inside the TFLite runtime could be recorded if the
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
index b56390b..cc87743 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
@@ -74,6 +74,9 @@
   // Allow subclasses to create a customized Op resolver during init.
   virtual std::unique_ptr<tflite::OpResolver> GetOpResolver() const;
 
+  // Allow subclasses to initialize a customized TFLite interpreter.
+  virtual TfLiteStatus InitInterpreter();
+
   // Create a BenchmarkListener that's specifically for TFLite profiling if
   // necessary.
   virtual std::unique_ptr<BenchmarkListener> MayCreateProfilingListener() const;
diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
index 9657c7e..ab150e8 100644
--- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
+++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
@@ -321,15 +321,23 @@
   void* data;
 } TfLitePtrUnion;
 
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+// Memory allocation strategies.
+//  * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated.
+//  * kTfLiteArenaRw: Arena allocated with no guarantees about persistence,
+//        and available during eval.
+//  * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and
+//        only available during eval.
+//  * kTfLiteDynamic: Allocated during eval, or for string tensors.
+//  * kTfLitePersistentRo: Allocated and populated during prepare. This is
+//        useful for tensors that can be computed during prepare and treated
+//        as constant inputs for downstream ops (also in prepare).
 typedef enum TfLiteAllocationType {
   kTfLiteMemNone = 0,
   kTfLiteMmapRo,
   kTfLiteArenaRw,
   kTfLiteArenaRwPersistent,
   kTfLiteDynamic,
+  kTfLitePersistentRo,
 } TfLiteAllocationType;
 
 // The delegates should use zero or positive integers to represent handles.
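
A hedged Python mirror of the enum above (illustration only), making explicit that the new value is appended after `kTfLiteDynamic` so the existing numeric values are unchanged:

```python
import enum

class TfLiteAllocationType(enum.IntEnum):
  kTfLiteMemNone = 0
  kTfLiteMmapRo = 1
  kTfLiteArenaRw = 2
  kTfLiteArenaRwPersistent = 3
  kTfLiteDynamic = 4
  kTfLitePersistentRo = 5  # new value, appended last for compatibility
```
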
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index 56aa9d5..9022afc 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -363,13 +363,20 @@
       }
       return 1;
     case BuiltinOperator_RESIZE_BILINEAR:
-    case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
       if (op_sig.options.resize.half_pixel_centers) {
         return 3;
       } else if (op_sig.input_types.at(0) == TensorType_INT8) {
         return 2;
       }
       return 1;
+    case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
+      if (op_sig.options.resize.half_pixel_centers ||
+          op_sig.options.resize.align_corners) {
+        return 3;
+      } else if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 2;
+      }
+      return 1;
 
     case BuiltinOperator_MAXIMUM:
     case BuiltinOperator_MINIMUM:
@@ -612,6 +619,8 @@
       if (resize_bilinear_option) {
         op_sig.options.resize.half_pixel_centers =
             resize_bilinear_option->half_pixel_centers();
+        op_sig.options.resize.align_corners =
+            resize_bilinear_option->align_corners();
       }
     } break;
     case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: {
@@ -620,6 +629,7 @@
       if (resize_nn_option) {
         op_sig.options.resize.half_pixel_centers =
             resize_nn_option->half_pixel_centers();
+        op_sig.options.resize.align_corners = resize_nn_option->align_corners();
       }
     } break;
     // TODO(b/150176627): Add tests for GetOpSignature.
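
A hedged Python restatement of the split versioning rules (not part of the TFLite tooling; `input_type` is an illustrative stand-in for `op_sig.input_types.at(0)`):

```python
def resize_nearest_neighbor_version(half_pixel_centers, align_corners,
                                    input_type):
  # Unlike RESIZE_BILINEAR, align_corners now also forces version 3 here.
  if half_pixel_centers or align_corners:
    return 3
  if input_type == "INT8":
    return 2
  return 1

assert resize_nearest_neighbor_version(False, True, "FLOAT32") == 3
assert resize_nearest_neighbor_version(False, False, "INT8") == 2
assert resize_nearest_neighbor_version(False, False, "FLOAT32") == 1
```
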
diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h
index fba6c94..4b0fe88 100644
--- a/tensorflow/lite/tools/versioning/op_version.h
+++ b/tensorflow/lite/tools/versioning/op_version.h
@@ -48,6 +48,7 @@
     } lstm;
     struct {
       bool half_pixel_centers;
+      bool align_corners;
     } resize;
     struct {
       int32_t num_dims;
diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
index 7d9039f..f0d8259 100644
--- a/tensorflow/lite/tools/versioning/op_version_test.cc
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -594,4 +594,64 @@
                                                         TensorType_INT32}};
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
 }
+TEST(OpVersionTest, VersioningResizeBilinearTest) {
+  // Default.
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_RESIZE_BILINEAR,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT32},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // align_corners=true is still version 1.
+  fake_op_sig.options.resize.align_corners = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // half_pixel_centers=true must be version 3.
+  fake_op_sig.options.resize.align_corners = false;
+  fake_op_sig.options.resize.half_pixel_centers = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  // int8 input is version 2.
+  fake_op_sig = {
+      .op = BuiltinOperator_RESIZE_BILINEAR,
+      .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT32},
+      .output_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  fake_op_sig.options.resize.half_pixel_centers = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+}
+TEST(OpVersionTest, VersioningResizeNearestNeighborTest) {
+  // Default.
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
+      .input_types =
+          std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT32},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+  // align_corners=true is version 3.
+  fake_op_sig.options.resize.align_corners = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  // half_pixel_centers=true must be version 3.
+  fake_op_sig.options.resize.align_corners = false;
+  fake_op_sig.options.resize.half_pixel_centers = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  // int8 input is version 2.
+  fake_op_sig = {
+      .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
+      .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT32},
+      .output_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  fake_op_sig.options.resize.align_corners = true;
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+}
 }  // namespace tflite
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 0b046ea..a49e4b7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -230,6 +230,7 @@
         "//tensorflow/python/tools:module_util",
         "//tensorflow/python/tools/api/generator:create_python_api",
         "//tensorflow/python/tpu:tpu_noestimator",
+        "//tensorflow/python/types",
         "//third_party/py/numpy",
     ],
 )
@@ -655,15 +656,15 @@
         "@com_google_absl//absl/types:optional",
     ] + if_static(
         extra_deps = [
-            "//tensorflow/core:eager_service_proto_cc",
-            "//tensorflow/core:master_proto_cc",
-            "//tensorflow/core:worker_proto_cc",
+            "//tensorflow/core/protobuf:eager_service_proto_cc",
+            "//tensorflow/core/protobuf:master_proto_cc",
+            "//tensorflow/core/protobuf:worker_proto_cc",
             "//tensorflow/core:version_lib",
         ],
         otherwise = [
-            "//tensorflow/core:eager_service_proto_cc_headers_only",
-            "//tensorflow/core:master_proto_cc_headers_only",
-            "//tensorflow/core:worker_proto_cc_headers_only",
+            "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only",
+            "//tensorflow/core/protobuf:master_proto_cc_headers_only",
+            "//tensorflow/core/protobuf:worker_proto_cc_headers_only",
         ],
     ),
 )
@@ -8049,14 +8050,14 @@
         "//tensorflow/core/platform",
     ] + if_static(
         extra_deps = [
-            "//tensorflow/core:eager_service_proto_cc",
-            "//tensorflow/core:master_proto_cc",
-            "//tensorflow/core:worker_proto_cc",
+            "//tensorflow/core/protobuf:eager_service_proto_cc",
+            "//tensorflow/core/protobuf:master_proto_cc",
+            "//tensorflow/core/protobuf:worker_proto_cc",
         ],
         otherwise = [
-            "//tensorflow/core:eager_service_proto_cc_headers_only",
-            "//tensorflow/core:master_proto_cc_headers_only",
-            "//tensorflow/core:worker_proto_cc_headers_only",
+            "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only",
+            "//tensorflow/core/protobuf:master_proto_cc_headers_only",
+            "//tensorflow/core/protobuf:worker_proto_cc_headers_only",
         ],
     ),
 )
diff --git a/tensorflow/python/autograph/g3doc/reference/control_flow.md b/tensorflow/python/autograph/g3doc/reference/control_flow.md
index 79cc0f3..cf580af 100644
--- a/tensorflow/python/autograph/g3doc/reference/control_flow.md
+++ b/tensorflow/python/autograph/g3doc/reference/control_flow.md
@@ -164,7 +164,7 @@
 #### Python values modified in TensorFlow control flow become Tensors
 
 If a symbol is modified in a TensorFlow control flow statement, then it becomes
-a `tf.Tensor`, even if it started off as a Python promitive value.
+a `tf.Tensor`, even if it started off as a Python primitive value.
 
 For example, the conditional below will run as a `tf.cond` (its condition is a
 `tf.Tensor`), which in turn will cause `i` to become a `tf.Tensor`.
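
A minimal sketch of the behavior described above, assuming TF 2.x `tf.function` semantics (the function is illustrative, not the doc's own example):

```python
import tensorflow as tf

@tf.function
def count_positive(x):
  i = 0            # starts as a Python int
  if x > 0:        # Tensor condition: AutoGraph lowers this to tf.cond
    i = i + 1      # modified inside TF control flow...
  return i         # ...so `i` is a tf.Tensor here, not a Python int

print(count_positive(tf.constant(2.0)))  # tf.Tensor(1, shape=(), dtype=int32)
```
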
diff --git a/tensorflow/python/autograph/g3doc/reference/generated_code.md b/tensorflow/python/autograph/g3doc/reference/generated_code.md
index b62911b..389fa53 100644
--- a/tensorflow/python/autograph/g3doc/reference/generated_code.md
+++ b/tensorflow/python/autograph/g3doc/reference/generated_code.md
@@ -66,7 +66,7 @@
 ```
 
 `tf.autograph.to_code` is a shortcut to obtain the generated code, and it's
-equivalent with calling `inspect.getsource(tf.autograph.to_code(f))`.
+equivalent to calling `inspect.getsource(tf.autograph.to_graph(f))`.
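
A hedged sketch of the equivalence just described (`f` is an illustrative function):

```python
import inspect
import tensorflow as tf

def f(x):
  if x > 0:
    x = x * x
  return x

# Both print the AutoGraph-generated source for `f`.
print(tf.autograph.to_code(f))
print(inspect.getsource(tf.autograph.to_graph(f)))
```
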
 
 #### Recording diagnostic information: `tf.autograph.set_verbosity`
 
diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
index c5a3a3d..5d835fd 100644
--- a/tensorflow/python/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -106,11 +106,12 @@
     with self.cached_session() as sess:
       self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
 
-  @test_util.run_v1_only("b/117943489")
+  @test_util.run_deprecated_v1
   def test_append_tensorarray(self):
     l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
     l1 = data_structures.list_append(l, 1)
     l2 = data_structures.list_append(l1, 2)
+
     with self.cached_session() as sess:
       self.assertAllEqual(self.evaluate(l1.stack()), [1])
       self.assertAllEqual(self.evaluate(l2.stack()), [1, 2])
diff --git a/tensorflow/python/autograph/utils/tensor_list_test.py b/tensorflow/python/autograph/utils/tensor_list_test.py
index bbbc3bf..017d97b 100644
--- a/tensorflow/python/autograph/utils/tensor_list_test.py
+++ b/tensorflow/python/autograph/utils/tensor_list_test.py
@@ -34,7 +34,6 @@
   def _shape(self, shape_tuple):
     return constant(shape_tuple, dtypes.int32)
 
-  @test_util.run_v1_only("b/117943489")
   def test_dynamic_list_append(self):
     l = []
     l = tl.dynamic_list_append(l, 1)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 17f559e..627979a 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -33,7 +33,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 5, 10)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 5, 12)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
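
A hedged usage sketch of the horizon machinery this file maintains, via the public `tf.compat` wrappers (dates are illustrative):

```python
import tensorflow as tf

# Gate new graph-generation behavior on the horizon date; the
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS variable shifts the effective date,
# as the comment above notes.
if tf.compat.forward_compatible(2020, 5, 12):
  pass  # emit the new op variant
else:
  pass  # keep emitting the old op variant

# Tests can temporarily push the horizon forward:
with tf.compat.forward_compatibility_horizon(2020, 5, 13):
  assert tf.compat.forward_compatible(2020, 5, 12)
```
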
 
diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py
index c1c2366..67dfadb 100644
--- a/tensorflow/python/data/experimental/ops/data_service_ops.py
+++ b/tensorflow/python/data/experimental/ops/data_service_ops.py
@@ -84,15 +84,29 @@
     if task_refresh_interval_hint_ms is None:
       task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE
 
+    self._dataset_id = ops.convert_to_tensor(
+        dataset_id, dtype=dtypes.int64, name="dataset_id")
+    self._processing_mode = ops.convert_to_tensor(
+        processing_mode, dtype=dtypes.string, name="processing_mode")
+    self._address = ops.convert_to_tensor(
+        address, dtype=dtypes.string, name="address")
+    self._protocol = ops.convert_to_tensor(
+        protocol, dtype=dtypes.string, name="protocol")
+    self._job_name = ops.convert_to_tensor(
+        job_name, dtype=dtypes.string, name="job_name")
+    self._max_outstanding_requests = ops.convert_to_tensor(
+        max_outstanding_requests,
+        dtype=dtypes.int64,
+        name="max_outstanding_requests")
     self._element_spec = input_dataset.element_spec
 
     variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
-        dataset_id=dataset_id,
-        processing_mode=processing_mode,
-        address=address,
-        protocol=protocol,
-        job_name=job_name,
-        max_outstanding_requests=max_outstanding_requests,
+        dataset_id=self._dataset_id,
+        processing_mode=self._processing_mode,
+        address=self._address,
+        protocol=self._protocol,
+        job_name=self._job_name,
+        max_outstanding_requests=self._max_outstanding_requests,
         task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
         iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
         ),
@@ -297,5 +311,8 @@
   Returns:
     Dataset: A `Dataset` of the elements produced by the data service.
   """
-  return _distribute(processing_mode, service, job_name,
-                     max_outstanding_requests)
+  return _distribute(
+      processing_mode=processing_mode,
+      service=service,
+      job_name=job_name,
+      max_outstanding_requests=max_outstanding_requests)
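
A hedged usage sketch matching the test added in the next file; the dispatcher address is an assumption for illustration, and `_distribute` is the internal entry point the public wrapper above forwards to:

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import data_service_ops

ds = tf.data.Dataset.range(10)
ds = ds.apply(
    data_service_ops._distribute(
        processing_mode="parallel_epochs",
        service="grpc://localhost:5000",  # assumed dispatcher address
        max_outstanding_requests=1))      # limit concurrent worker requests
for element in ds:  # assumes a tf.data service is running at that address
  print(element)
```
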
diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
index eac1c67..217c586 100644
--- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
@@ -216,6 +216,21 @@
       self.assertEqual(i, val)
 
   @combinations.generate(test_base.eager_only_combinations())
+  def testMaxOutstandingRequests(self):
+    num_elements = 10
+    num_workers = 3
+    service = self.create_cluster(num_workers)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds = ds.apply(
+        data_service_ops._distribute(
+            "parallel_epochs",
+            service,
+            max_outstanding_requests=1,
+            task_refresh_interval_hint_ms=20))
+    self.assertCountEqual(num_workers * list(range(num_elements)),
+                          self.getDatasetOutput(ds))
+
+  @combinations.generate(test_base.eager_only_combinations())
   def testInsideFunction(self):
     num_workers = 3
     num_elements = 10
diff --git a/tensorflow/python/distribute/multi_worker_test_base.py b/tensorflow/python/distribute/multi_worker_test_base.py
index a80ca83..408cad2 100644
--- a/tensorflow/python/distribute/multi_worker_test_base.py
+++ b/tensorflow/python/distribute/multi_worker_test_base.py
@@ -50,6 +50,7 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import server_lib
+from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
 from tensorflow.python.util.compat import collections_abc
 
@@ -559,6 +560,10 @@
     return subprocess.Popen(
         cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
 
+  @deprecation.deprecated(
+      None, '`run_multiple_tasks_in_processes` is deprecated; any new test '
+      'requiring multiple processes should use `multi_process_runner` for '
+      'better support of log printing, streaming, and more functionality.')
   def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec):
     """Run `cmd_args` in a process for each task in `cluster_spec`."""
     processes = {}
@@ -570,6 +575,10 @@
         processes[task_type].append(p)
     return processes
 
+  @deprecation.deprecated(
+      None, '`join_independent_workers` is deprecated; any new test '
+      'requiring multiple processes should use `multi_process_runner` for '
+      'better support of log printing, streaming, and more functionality.')
   def join_independent_workers(self, worker_processes):
     return_codes = []
     for p in nest.flatten(worker_processes):
@@ -585,6 +594,10 @@
     for return_code in return_codes:
       self.assertEqual(return_code, 0)
 
+  @deprecation.deprecated(
+      None, '`stream_stderr` is deprecated; any new test '
+      'requiring multiple processes should use `multi_process_runner` for '
+      'better support of log printing, streaming, and more functionality.')
   def stream_stderr(self, processes, print_only_first=False):
     """Consume stderr of all processes and print to stdout.
 
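
For reference, a hedged sketch of the decorator pattern used above; `legacy_helper` is hypothetical. `deprecation.deprecated(date, instructions)` wraps the function and logs a warning when it is first called; a `date` of `None` means deprecated as of now:

```python
from tensorflow.python.util import deprecation

@deprecation.deprecated(
    None, 'Use `multi_process_runner` for new multi-process tests.')
def legacy_helper():  # hypothetical function, for illustration only
  return 42

legacy_helper()  # logs a deprecation warning (once) before returning 42
```
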
diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD
index e7526a5..930816d 100644
--- a/tensorflow/python/distribute/parallel_device/BUILD
+++ b/tensorflow/python/distribute/parallel_device/BUILD
@@ -1,4 +1,8 @@
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+
 package(
+    default_visibility = ["//tensorflow:internal"],
     licenses = ["notice"],  # Apache 2.0
 )
 
@@ -13,6 +17,7 @@
     srcs = ["parallel_device.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":parallel_device_ops",
         ":saving",
         "//tensorflow/python:_pywrap_parallel_device",
     ],
@@ -25,6 +30,25 @@
     deps = ["//tensorflow/python:framework_ops"],
 )
 
+tf_gen_op_wrapper_py(
+    name = "parallel_device_ops_py",
+    out = "gen_parallel_device_ops.py",
+    deps = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
+)
+
+tf_custom_op_library(
+    name = "_parallel_device_ops.so",
+    srcs = ["//tensorflow/c/eager/parallel_device:parallel_device_ops_srcs"],
+)
+
+tf_custom_op_py_library(
+    name = "parallel_device_ops",
+    dso = [":_parallel_device_ops.so"],
+    kernels = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
+    visibility = ["//tensorflow:internal"],
+    deps = [":parallel_device_ops_py"],
+)
+
 py_test(
     name = "parallel_device_test",
     srcs = ["parallel_device_test.py"],
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py
index 982b061..2dbdc65 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device.py
@@ -22,11 +22,17 @@
 import threading
 
 from tensorflow.python import _pywrap_parallel_device
+from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
 from tensorflow.python.distribute.parallel_device import saving
 from tensorflow.python.eager import context
+from tensorflow.python.framework import load_library
 from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
 from tensorflow.python.tpu.ops import tpu_ops
 
+load_library.load_op_library(
+    resource_loader.get_path_to_datafile("_parallel_device_ops.so"))
+
 _next_device_number = 0
 _next_device_number_lock = threading.Lock()
 
@@ -58,6 +64,8 @@
     device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
         self.name, self.components)
     context.register_custom_device(device, self.name, device_info)
+    with ops.device(self.name):
+      self._device_ids = gen_parallel_device_ops.device_id()
 
   def pack(self, tensors):
     """Create a tensor on the parallel device from a sequence of tensors.
@@ -84,6 +92,18 @@
       return tpu_ops.tpu_replicated_output(
           parallel_tensor, num_replicas=len(self.components))
 
+  @property
+  def device_ids(self):
+    """A parallel tensor with scalar integers numbering component devices.
+
+    Each device ID is placed on its corresponding device, in the same order as
+    the `components` constructor argument.
+
+    Returns:
+      A parallel tensor containing 0 on the first device, 1 on the second, etc.
+    """
+    return self._device_ids
+
   # TODO(allenl): Fixing saving in Python is a bit odd. One alternative would be
   # to provide a hook for the custom device to create save specs/etc., then call
   # that hook from the default variable implementation if the variable is on a
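
A hedged usage sketch of the new `device_ids` property, assuming the constructor takes the component device names and that two CPU devices are configured (the device strings are illustrative; the test file below exercises the same pattern):

```python
from tensorflow.python.distribute.parallel_device import parallel_device

device = parallel_device.ParallelDevice(
    components=["/job:localhost/device:CPU:0",  # assumed component devices
                "/job:localhost/device:CPU:1"])
# One scalar per component, placed on that component: [0, 1].
ids = device.unpack(device.device_ids)
```
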
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
index d3f3417..e35eb60 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device_test.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
@@ -119,6 +119,12 @@
     self.assertIn(self.device.components[0], outputs[0].backing_device)
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
+  def test_device_id(self):
+    device_ids = self.device.unpack(self.device.device_ids)
+    self.assertAllClose([0, 1], device_ids)
+    self.assertIn(self.device.components[0], device_ids[0].backing_device)
+    self.assertIn(self.device.components[1], device_ids[1].backing_device)
+
   def test_collective_reduce(self):
     with ops.device(self.device.name):
       x = self.device.pack(
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 6e51b84..b574c52 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -96,35 +96,34 @@
 
 @tf_export("distribute.experimental.TPUStrategy", v1=[])
 class TPUStrategy(distribute_lib.Strategy):
-  """TPU distribution strategy implementation."""
+  """TPU distribution strategy implementation.
+
+  To construct a TPUStrategy object, you need to run the
+  initialization code as follows:
+
+  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+  >>> tf.config.experimental_connect_to_cluster(resolver)
+  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
+  >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
+
+  While using distribution strategies, the variables created within the
+  strategy's scope will be replicated across all the replicas and can be kept
+  in sync using all-reduce algorithms.
+
+  To run TF2 programs on TPUs, you can either use `.compile` and
+  `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
+  training loop by calling `strategy.run` directly. Note that
+  TPUStrategy doesn't support pure eager execution, so please make sure the
+  function passed into `strategy.run` is a `tf.function` or
+  `strategy.run` is called inside a `tf.function` if eager
+  behavior is enabled.
+  """
 
   def __init__(self,
                tpu_cluster_resolver=None,
                device_assignment=None):
     """Synchronous training in TPU donuts or Pods.
 
-    To construct a TPUStrategy object, you need to run the
-    initialization code as below:
-
-    ```python
-    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
-    tf.config.experimental_connect_to_cluster(resolver)
-    tf.tpu.experimental.initialize_tpu_system(resolver)
-    strategy = tf.distribute.experimental.TPUStrategy(resolver)
-    ```
-
-    While using distribution strategies, the variables created within strategy's
-    scope will be replicated across all the replicas and can be kept in sync
-    using all-reduce algorithms.
-
-    To run TF2 programs on TPUs, you can either use `.compile` and
-    `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
-    training loop by calling `strategy.run` directly. Note that
-    TPUStrategy doesn't support pure eager execution, so please make sure the
-    function passed into `strategy.run` is a `tf.function` or
-    `strategy.run` is called inside a `tf.function` if eager
-    behavior is enabled.
-
     Args:
       tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
         which provides information about the TPU cluster.
@@ -209,26 +208,26 @@
     Users can pass strategy specific options to `options` argument. An example
     to enable bucketizing dynamic shapes in `TPUStrategy.run`
     is:
-    ```python
 
-    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
-    tf.config.experimental_connect_to_cluster(resolver)
-    tf.tpu.experimental.initialize_tpu_system(resolver)
-    strategy = tf.distribute.experimental.TPUStrategy(tpu='')
+    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+    >>> tf.config.experimental_connect_to_cluster(resolver)
+    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
+    >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
 
-    options = tf.distribute.RunOptions()
-    options.experimental_bucketizing_dynamic_shape = True
+    >>> options = tf.distribute.RunOptions(
+    ...     experimental_bucketizing_dynamic_shape=True)
 
-    iterator = iter(inputs)
+    >>> dataset = tf.data.Dataset.range(
+    ...    strategy.num_replicas_in_sync, output_type=tf.dtypes.float32).batch(
+    ...        strategy.num_replicas_in_sync, drop_remainder=True)
+    >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
 
-    @tf.function()
-    def step_fn(inputs):
-      output = tf.reduce_sum(inputs)
-      return output
+    >>> @tf.function()
+    ... def step_fn(inputs):
+    ...  output = tf.reduce_sum(inputs)
+    ...  return output
 
-      strategy.run(step_fn, args=(next(iterator),),
-                                   options=options)
-    ```
+    >>> strategy.run(step_fn, args=(next(input_iterator),), options=options)
 
     Args:
       fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 4fe3d28..444915a 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -38,6 +38,7 @@
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.types import core
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
@@ -422,7 +423,8 @@
     return hash((self.name, self.graph, self.traceback, self.type))
 
 
-class DistributedVariable(DistributedDelegate, variables_lib.Variable):
+class DistributedVariable(DistributedDelegate, variables_lib.Variable,
+                          core.Tensor):
   """Holds a map from replica to variables."""
 
   # TODO(josh11b): Support changing the set of variables if e.g. if new
@@ -741,9 +743,6 @@
     pass
 
 
-ops.register_dense_tensor_like_type(DistributedVariable)
-
-
 def _validate_colocate_extended(v, extended):
   variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
   if variable_strategy.extended is not extended:
@@ -1380,7 +1379,7 @@
   return val
 
 
-class AggregatingVariable(variables_lib.Variable):
+class AggregatingVariable(variables_lib.Variable, core.Tensor):
   """A wrapper around a variable that aggregates updates across replicas."""
 
   def __init__(self, strategy, v, aggregation):
@@ -1649,4 +1648,3 @@
 
 ops.register_tensor_conversion_function(AggregatingVariable,
                                         _tensor_conversion_aggregate)
-ops.register_dense_tensor_like_type(AggregatingVariable)
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index daa7e55..67ed86b 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -56,6 +56,7 @@
 from tensorflow.python.tpu import tpu_strategy_util
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training.tracking import util as trackable_utils
+from tensorflow.python.types import core
 from tensorflow.python.util import nest
 
 
@@ -623,10 +624,10 @@
       v = variables_lib.Variable(
           0., synchronization=synchronization, aggregation=aggregation)
     # In cross replica context.
-    self.assertTrue(ops.is_dense_tensor_like(v))
+    self.assertIsInstance(v, core.Tensor)
     # In replica context.
     distribution.run(
-        lambda v: self.assertTrue(ops.is_dense_tensor_like(v)), args=(v,))
+        lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))
 
   def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                         aggregation):
@@ -645,9 +646,9 @@
       # values is not allowed when aggregation is SUM. See
       # `cross_device_ops.reduce_non_distributed_value`.
       delta = array_ops.identity(1.)
-      self.assertTrue(ops.is_dense_tensor_like(v.assign(delta)))
-      self.assertTrue(ops.is_dense_tensor_like(v.assign_sub(delta)))
-      self.assertTrue(ops.is_dense_tensor_like(v.assign_add(delta)))
+      self.assertIsInstance(v.assign(delta), core.Tensor)
+      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
+      self.assertIsInstance(v.assign_add(delta), core.Tensor)
 
     # In cross replica context we return a PerReplica which is not Tensor like
     # yet.
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index f2f1327..227fca5 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -180,22 +180,18 @@
         func()  # Warmup.
       self._run(func, 3000)
 
-  @test_util.disable_tfrt("Scalars are not handled correctly")
   def benchmark_create_float_constant(self):
     self._benchmark_create_constant(42.0, dtype=None)
 
-  @test_util.disable_tfrt("Scalars are not handled correctly")
   def benchmark_create_float_constant_uncached(self):
     self._benchmark_create_constant(42.0, dtype=None, cached=False)
 
-  @test_util.disable_tfrt("Scalars are not handled correctly")
   def benchmark_create_int32_constant(self):
     if context.num_gpus():
       return  # int32 constants are always allocated on CPU.
 
     self._benchmark_create_constant(42, dtype=dtypes.int32)
 
-  @test_util.disable_tfrt("Scalars are not handled correctly")
   def benchmark_create_int32_constant_uncached(self):
     if context.num_gpus():
       return  # int32 constants are always allocated on CPU.
@@ -211,21 +207,17 @@
         func()  # Warmup.
       self._run(func, 30000)
 
-  @test_util.disable_tfrt("Scalars are not handled correctly")
   def benchmark_add_float_scalars(self):
     self._benchmark_add(42.0, 24.0)
 
-  @test_util.disable_tfrt("Scalars are not handled correctly")
   def benchmark_add_int32_scalars(self):
     self._benchmark_add(42, 24)
 
-  @test_util.disable_tfrt("Scalars are not handled correctly")
   def benchmark_add_float_scalar_tensor(self):
     tensor_a = constant_op.constant(42.0)
     tensor_b = constant_op.constant(24.0)
     self._benchmark_add(tensor_a, tensor_b)
 
-  @test_util.disable_tfrt("Scalars are not handled correctly")
   def benchmark_add_int32_scalar_tensor(self):
     tensor_a = constant_op.constant(42)
     tensor_b = constant_op.constant(24)
diff --git a/tensorflow/python/eager/benchmarks_test_base.py b/tensorflow/python/eager/benchmarks_test_base.py
index 552d844..3d81d08 100644
--- a/tensorflow/python/eager/benchmarks_test_base.py
+++ b/tensorflow/python/eager/benchmarks_test_base.py
@@ -32,4 +32,6 @@
         "examples_per_sec": float("{0:.3f}".format(num_iters / total_time)),
         "us_per_example": float("{0:.3f}".format(total_time * 1e6 / num_iters))
     }
-    self.report_benchmark(iters=num_iters, wall_time=mean_us, extras=extras)
+    benchmark_name = self._get_benchmark_name()
+    self.report_benchmark(
+        iters=num_iters, wall_time=mean_us, extras=extras, name=benchmark_name)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 9b8f7cf..43652d5 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -62,6 +62,7 @@
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.platform import app
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.types import core as core_tf_types
 from tensorflow.python.types import internal
 from tensorflow.python.util import compat
 from tensorflow.python.util import decorator_utils
@@ -213,53 +214,11 @@
   return None
 
 
-_TENSOR_LIKE_TYPES = tuple()
-
-
+# Deprecated - do not use.
+# This stub exists only to avoid breaking estimator and tensorflow-mesh, which
+# depend on this internal API; it should be safe to remove after TF 2.3 is
+# released.
 def is_dense_tensor_like(t):
-  """EXPERIMENTAL: Returns true if `t` implements the tensor interface.
-
-  See `register_dense_tensor_like_type()` for the current definition of a
-  "tensor-like type".
-
-  Args:
-    t: An object.
-
-  Returns:
-    True iff `t` is an instance of one of the registered "tensor-like" types.
-  """
-  return isinstance(t, _TENSOR_LIKE_TYPES)
-
-
-def register_dense_tensor_like_type(tensor_type):
-  """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
-
-  A "tensor-like type" can represent a single dense tensor, and implements
-  the `name`, `dtype` and `shape` properties.
-
-  Args:
-    tensor_type: A type implementing the tensor interface.
-
-  Raises:
-    TypeError: If `tensor_type` does not implement the tensor interface.
-  """
-  if not (hasattr(tensor_type, "name") and
-          isinstance(tensor_type.name, property)):
-    raise TypeError("Type %s does not define a `name` property" %
-                    tensor_type.__name__)
-  if not (hasattr(tensor_type, "dtype") and
-          isinstance(tensor_type.dtype, property)):
-    raise TypeError("Type %s does not define a `dtype` property" %
-                    tensor_type.__name__)
-  if not (hasattr(tensor_type, "shape") and
-          isinstance(tensor_type.shape, property)):
-    raise TypeError("Type %s does not define a `shape` property" %
-                    tensor_type.__name__)
-  # We expect this list to be small, so choose quadratic complexity
-  # for registration, so that we have a tuple that can be used for
-  # more efficient `isinstance` checks later.
-  global _TENSOR_LIKE_TYPES
-  _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
+  return isinstance(t, core_tf_types.Tensor)
 
 
 def uid():
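
With the registration API reduced to this stub, "dense tensor-like" is expressed by subclassing `core.Tensor`, as the variable wrappers elsewhere in this change now do; a minimal hedged sketch (`MyTensorLike` is hypothetical):

```python
from tensorflow.python.types import core

class MyTensorLike(core.Tensor):  # hypothetical, for illustration only
  """Instances satisfy isinstance(x, core.Tensor), so the stub returns True."""

assert isinstance(MyTensorLike(), core.Tensor)
```
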
@@ -304,7 +263,7 @@
 
 # TODO(mdan): This object should subclass Symbol, not just Tensor.
 @tf_export("Tensor")
-class Tensor(internal.NativeObject):
+class Tensor(internal.NativeObject, core_tf_types.Tensor):
   """A tensor is a multidimensional array of elements represented by a
 
   `tf.Tensor` object.  All elements are of a single known data type.
@@ -1305,9 +1264,6 @@
 EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
 
 
-register_dense_tensor_like_type(Tensor)
-
-
 @tf_export(v1=["convert_to_tensor"])
 def convert_to_tensor_v1(value,
                          dtype=None,
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 20f58a0..322df8f 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -3268,56 +3268,6 @@
         test_ops.old()
 
 
-class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
-
-  @test_util.disable_tfrt("Graph is not supported yet.")
-  def testSuccess(self):
-    op = ops.Operation(
-        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
-    t = op.outputs[0]
-    self.assertTrue(ops.is_dense_tensor_like(t))
-
-    v = variables.Variable([17])
-    self.assertTrue(ops.is_dense_tensor_like(v))
-
-  class BadClassNoName(object):
-    pass
-
-  class BadClassBadName(object):
-
-    def name(self):
-      pass
-
-  class BadClassNoDtype(object):
-
-    @property
-    def name(self):
-      pass
-
-  class BadClassBadDtype(object):
-
-    @property
-    def name(self):
-      pass
-
-    def dtype(self):
-      pass
-
-  def testBadClass(self):
-    with self.assertRaisesRegexp(TypeError, "`name`"):
-      ops.register_dense_tensor_like_type(
-          DenseTensorLikeTypeTest.BadClassNoName)
-    with self.assertRaisesRegexp(TypeError, "`name`"):
-      ops.register_dense_tensor_like_type(
-          DenseTensorLikeTypeTest.BadClassBadName)
-    with self.assertRaisesRegexp(TypeError, "`dtype`"):
-      ops.register_dense_tensor_like_type(
-          DenseTensorLikeTypeTest.BadClassNoDtype)
-    with self.assertRaisesRegexp(TypeError, "`dtype`"):
-      ops.register_dense_tensor_like_type(
-          DenseTensorLikeTypeTest.BadClassBadDtype)
-
-
 class NameScopeTest(test_util.TensorFlowTestCase):
 
   def testStripAndPrependScope(self):
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 5038859..968b635 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -26,6 +26,7 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.types import core
 from tensorflow.python.types import internal
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
@@ -1009,7 +1010,7 @@
     `True` if `x` is a tensor or "tensor-like", `False` if not.
   """
   return (isinstance(x, internal.NativeObject) or
-          ops.is_dense_tensor_like(x) or
+          isinstance(x, core.Tensor) or
           getattr(x, "is_tensor_like", False))
 
 
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 503f6cf..2700fae 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -393,6 +393,9 @@
           False, shape=(), name='keras_learning_phase')
 
 
+@deprecated('2020-10-11',
+            'Simply pass a True/False value to the `training` argument '
+            'of the `__call__` method of your layer or model.')
 @keras_export('keras.backend.set_learning_phase')
 def set_learning_phase(value):
   """Sets the learning phase to a fixed value.
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 6748a57..db326ea 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -307,14 +307,20 @@
       end_hook_name = hook_name
       begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
 
-      threshold_time = 0.5 * batch_time
+      threshold_time = 1.5 * batch_time
       warning_msg = ('Callbacks method `{hook}` is slow compared to '
-                     'the batch time. Check your callbacks.')
+                     'the batch time (batch time: {batch_time:.4f}s vs '
+                     '`{hook}` time: {cbk_time:.4f}s). Check your callbacks.')
       if self._timing[begin_hook_name] > threshold_time:
-        logging.warning(warning_msg.format(hook=begin_hook_name))
+        logging.warning(warning_msg.format(
+            hook=begin_hook_name,
+            batch_time=batch_time,
+            cbk_time=self._timing[begin_hook_name]))
       if self._timing[end_hook_name] > threshold_time:
-        logging.warning(warning_msg.format(hook=end_hook_name))
-
+        logging.warning(warning_msg.format(
+            hook=end_hook_name,
+            batch_time=batch_time,
+            cbk_time=self._timing[end_hook_name]))
       self._check_timing = False
       self._batch_start_time = None
 
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 9d15f87..2f1256e 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -302,8 +302,8 @@
           epochs=10,
           callbacks=[SleepCallback()])
     warning_msg = ('Callbacks method `on_train_batch_end` is slow compared '
-                   'to the batch time. Check your callbacks.')
-    self.assertIn(warning_msg, warning_messages)
+                   'to the batch time')
+    self.assertIn(warning_msg, '\n'.join(warning_messages))
 
   @keras_parameterized.run_with_all_model_types(exclude_models='functional')
   @keras_parameterized.run_all_keras_modes
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 0a40ce3..16188af 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -62,6 +62,7 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
+from tensorflow.python.types import core
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
@@ -3143,7 +3144,7 @@
     The possibly-converted 'value'.
   """
   if issparse is not None and issparse(value):
-    if ops.is_dense_tensor_like(expected_input):
+    if isinstance(expected_input, core.Tensor):
       if ops.executing_eagerly_outside_functions():
         # In TF2 we do not silently densify sparse matrices.
         raise ValueError('A SciPy sparse matrix was passed to a model '
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
index c43ca21..29e5a68 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
@@ -23,9 +23,10 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.types import core
 
 
-class AutoCastVariable(variables.Variable):
+class AutoCastVariable(variables.Variable, core.Tensor):
   """Variable that will cast itself to a different dtype in applicable contexts.
 
   This class wraps a floating-point `tf.Variable`. It emulates the variable
@@ -417,7 +418,6 @@
 
 ops.register_tensor_conversion_function(AutoCastVariable,
                                         AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
-ops.register_dense_tensor_like_type(AutoCastVariable)
 
 
 def create_autocast_variable(variable):
diff --git a/tensorflow/python/keras/preprocessing/BUILD b/tensorflow/python/keras/preprocessing/BUILD
index 403bc6e..24260fb 100644
--- a/tensorflow/python/keras/preprocessing/BUILD
+++ b/tensorflow/python/keras/preprocessing/BUILD
@@ -85,6 +85,7 @@
     deps = [
         ":image",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/keras",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py
index 3af573f..953962c 100644
--- a/tensorflow/python/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/preprocessing/image.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 # pylint: disable=invalid-name
 # pylint: disable=g-import-not-at-top
+# pylint: disable=g-classes-have-attributes
 """Set of tools for real-time data augmentation on image data.
 """
 from __future__ import absolute_import
@@ -35,6 +36,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import image_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
 
@@ -49,6 +51,7 @@
 apply_affine_transform = image.apply_affine_transform
 
 
+@keras_export('keras.preprocessing.image.smart_resize', v1=[])
 def smart_resize(x, size, interpolation='bilinear'):
   """Resize images to a target size without aspect ratio distortion.
 
@@ -65,7 +68,7 @@
   ```
 
   However, if you do this, you distort the aspect ratio of your images, since
-  in general they do not all have the same aspect ratio. This is
+  in general they do not all have the same aspect ratio as `size`. This is
   fine in many cases, but not always (e.g. for GANs this can be a problem).
 
   Note that passing the argument `preserve_aspect_ratio=True` to `resize`
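
A hedged usage sketch of the newly exported symbol, using the TF2 path implied by the `keras_export` above (the input shape is illustrative):

```python
import tensorflow as tf

img = tf.random.uniform((300, 200, 3))  # (height, width, channels)
out = tf.keras.preprocessing.image.smart_resize(img, size=(200, 200))
print(out.shape)                        # (200, 200, 3), no aspect distortion
```
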
@@ -458,6 +461,123 @@
         **kwargs)
 
 
+class DataFrameIterator(image.DataFrameIterator, Iterator):
+  """Iterator capable of reading images from a directory on disk as a dataframe.
+
+  Arguments:
+      dataframe: Pandas dataframe containing the filepaths relative to
+        `directory` (or absolute paths if `directory` is None) of the images in
+        a string column. It should include other column(s) depending on the
+        `class_mode`:
+          - if `class_mode` is `"categorical"` (default value) it must include
+            the `y_col` column with the class(es) of each image. Values in the
+            column can be string/list/tuple if a single class or list/tuple if
+            multiple classes.
+          - if `class_mode` is `"binary"` or `"sparse"` it must include the
+            given `y_col` column with class values as strings.
+          - if `class_mode` is `"raw"` or `"multi_output"` it should contain
+            the columns specified in `y_col`.
+          - if `class_mode` is `"input"` or `None` no extra column is needed.
+      directory: string, path to the directory to read images from. If `None`,
+        data in `x_col` column should be absolute paths.
+      image_data_generator: Instance of `ImageDataGenerator` to use for random
+        transformations and normalization. If None, no transformations or
+        normalization will be applied.
+      x_col: string, column in `dataframe` that contains the filenames (or
+        absolute paths if `directory` is `None`).
+      y_col: string or list, column/s in `dataframe` that has the target data.
+      weight_col: string, column in `dataframe` that contains the sample
+          weights. Default: `None`.
+      target_size: tuple of integers, dimensions to resize input images to.
+      color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. Color mode to read
+        images.
+      classes: Optional list of strings, classes to use (e.g. `["dogs",
+        "cats"]`). If None, all classes in `y_col` will be used.
+      class_mode: one of "binary", "categorical", "input", "multi_output",
+          "raw", "sparse" or None. Default: "categorical".
+          Mode for yielding the targets:
+          - `"binary"`: 1D numpy array of binary labels,
+          - `"categorical"`: 2D numpy array of one-hot encoded labels. Supports
+            multi-label output.
+          - `"input"`: images identical to input images (mainly used to work
+            with autoencoders),
+          - `"multi_output"`: list with the values of the different columns,
+          - `"raw"`: numpy array of values in `y_col` column(s),
+          - `"sparse"`: 1D numpy array of integer labels, - `None`, no targets
+            are returned (the generator will only yield batches of image data,
+            which is useful to use in `model.predict_generator()`).
+      batch_size: Integer, size of a batch.
+      shuffle: Boolean, whether to shuffle the data between epochs.
+      seed: Random seed for data shuffling.
+      data_format: String, one of `channels_first`, `channels_last`.
+      save_to_dir: Optional directory where to save the pictures being yielded,
+        in a viewable format. This is useful for visualizing the random
+        transformations being applied, for debugging purposes.
+      save_prefix: String prefix to use for saving sample images (if
+        `save_to_dir` is set).
+      save_format: Format to use for saving sample images (if `save_to_dir` is
+        set).
+      subset: Subset of data (`"training"` or `"validation"`) if
+        validation_split is set in ImageDataGenerator.
+      interpolation: Interpolation method used to resample the image if the
+        target size is different from that of the loaded image. Supported
+        methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3
+        or newer is installed, "lanczos" is also supported. If PIL version 3.4.0
+        or newer is installed, "box" and "hamming" are also supported. By
+        default, "nearest" is used.
+      dtype: Dtype to use for the generated arrays.
+      validate_filenames: Boolean, whether to validate image filenames in
+        `x_col`. If `True`, invalid images will be ignored. Disabling this
+        option can lead to speed-up in the instantiation of this class.
+        Default: `True`.
+  """
+
+  def __init__(
+      self,
+      dataframe,
+      directory=None,
+      image_data_generator=None,
+      x_col='filename',
+      y_col='class',
+      weight_col=None,
+      target_size=(256, 256),
+      color_mode='rgb',
+      classes=None,
+      class_mode='categorical',
+      batch_size=32,
+      shuffle=True,
+      seed=None,
+      data_format='channels_last',
+      save_to_dir=None,
+      save_prefix='',
+      save_format='png',
+      subset=None,
+      interpolation='nearest',
+      dtype='float32',
+      validate_filenames=True):
+    super(DataFrameIterator, self).__init__(
+        dataframe=dataframe,
+        directory=directory,
+        image_data_generator=image_data_generator,
+        x_col=x_col,
+        y_col=y_col,
+        weight_col=weight_col,
+        target_size=target_size,
+        color_mode=color_mode,
+        classes=classes,
+        class_mode=class_mode,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        seed=seed,
+        data_format=data_format,
+        save_to_dir=save_to_dir,
+        save_prefix=save_prefix,
+        save_format=save_format,
+        subset=subset,
+        interpolation=interpolation,
+        dtype=dtype,
+        validate_filenames=validate_filenames
+    )
+
+
 @keras_export('keras.preprocessing.image.ImageDataGenerator')
 class ImageDataGenerator(image.ImageDataGenerator):
   """Generate batches of tensor image data with real-time data augmentation.
@@ -685,6 +805,302 @@
         validation_split=validation_split,
         **kwargs)
 
+  def flow(self,
+           x,
+           y=None,
+           batch_size=32,
+           shuffle=True,
+           sample_weight=None,
+           seed=None,
+           save_to_dir=None,
+           save_prefix='',
+           save_format='png',
+           subset=None):
+    """Takes data & label arrays, generates batches of augmented data.
+
+    Arguments:
+        x: Input data. Numpy array of rank 4 or a tuple. If tuple, the first
+          element should contain the images and the second element another numpy
+          array or a list of numpy arrays that gets passed to the output without
+          any modifications. Can be used to feed the model miscellaneous data
+          along with the images. In case of grayscale data, the channels axis of
+          the image array should have value 1, in case of RGB data, it should
+          have value 3, and in case of RGBA data, it should have value 4.
+        y: Labels.
+        batch_size: Int (default: 32).
+        shuffle: Boolean (default: True).
+        sample_weight: Sample weights.
+        seed: Int (default: None).
+        save_to_dir: None or str (default: None). This allows you to optionally
+          specify a directory to which to save the augmented pictures being
+          generated (useful for visualizing what you are doing).
+        save_prefix: Str (default: `''`). Prefix to use for filenames of saved
+          pictures (only relevant if `save_to_dir` is set).
+        save_format: one of "png", "jpeg"
+            (only relevant if `save_to_dir` is set). Default: "png".
+        subset: Subset of data (`"training"` or `"validation"`) if
+          `validation_split` is set in `ImageDataGenerator`.
+
+    Returns:
+        An `Iterator` yielding tuples of `(x, y)`
+            where `x` is a numpy array of image data
+            (in the case of a single image input) or a list
+            of numpy arrays (in the case with
+            additional inputs) and `y` is a numpy array
+            of corresponding labels. If 'sample_weight' is not None,
+            the yielded tuples are of the form `(x, y, sample_weight)`.
+            If `y` is None, only the numpy array `x` is returned.
+    """
+    return NumpyArrayIterator(
+        x,
+        y,
+        self,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        sample_weight=sample_weight,
+        seed=seed,
+        data_format=self.data_format,
+        save_to_dir=save_to_dir,
+        save_prefix=save_prefix,
+        save_format=save_format,
+        subset=subset)
+
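As a usage sketch of the `flow` method above (the array shapes and augmentation settings here are illustrative, not part of this patch):

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 100 RGB images of 32x32 with binary labels (illustrative data).
x = np.random.rand(100, 32, 32, 3).astype('float32')
y = np.random.randint(0, 2, size=(100,))

datagen = ImageDataGenerator(rotation_range=20, horizontal_flip=True)
it = datagen.flow(x, y, batch_size=16, shuffle=True, seed=1)

batch_x, batch_y = next(it)  # one augmented batch
print(batch_x.shape, batch_y.shape)  # (16, 32, 32, 3) (16,)
```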
+  def flow_from_directory(self,
+                          directory,
+                          target_size=(256, 256),
+                          color_mode='rgb',
+                          classes=None,
+                          class_mode='categorical',
+                          batch_size=32,
+                          shuffle=True,
+                          seed=None,
+                          save_to_dir=None,
+                          save_prefix='',
+                          save_format='png',
+                          follow_links=False,
+                          subset=None,
+                          interpolation='nearest'):
+    """Takes the path to a directory & generates batches of augmented data.
+
+    Arguments:
+        directory: string, path to the target directory. It should contain one
+          subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images inside
+          each of the subdirectories' directory trees will be included in the
+          generator. See [this script](
+            https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d)
+              for more details.
+        target_size: Tuple of integers `(height, width)`, defaults to
+          `(256, 256)`. The dimensions to which all images found will be
+          resized.
+        color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb". Whether
+          the images will be converted to have 1, 3, or 4 channels.
+        classes: Optional list of class subdirectories
+            (e.g. `['dogs', 'cats']`). Default: None. If not provided, the list
+              of classes will be automatically inferred from the subdirectory
+              names/structure under `directory`, where each subdirectory will be
+              treated as a different class (and the order of the classes, which
+              will map to the label indices, will be alphanumeric). The
+              dictionary containing the mapping from class names to class
+              indices can be obtained via the attribute `class_indices`.
+        class_mode: One of "categorical", "binary", "sparse",
+            "input", or None. Default: "categorical".
+            Determines the type of label arrays that are returned:
+            - "categorical" will be 2D one-hot encoded labels,
+            - "binary" will be 1D binary labels,
+            - "sparse" will be 1D integer labels,
+            - "input" will be images identical to input images (mainly used
+              to work with autoencoders),
+            - If None, no labels are returned (the generator will only yield
+              batches of image data, which is useful to use with
+              `model.predict_generator()`). Please note that in case of
+              class_mode None, the data still needs to reside in a subdirectory
+              of `directory` for it to work correctly.
+        batch_size: Size of the batches of data (default: 32).
+        shuffle: Whether to shuffle the data (default: True). If set to False,
+          sorts the data in alphanumeric order.
+        seed: Optional random seed for shuffling and transformations.
+        save_to_dir: None or str (default: None). This allows you to optionally
+          specify a directory to which to save the augmented pictures being
+          generated (useful for visualizing what you are doing).
+        save_prefix: Str. Prefix to use for filenames of saved pictures (only
+          relevant if `save_to_dir` is set).
+        save_format: One of "png", "jpeg"
+            (only relevant if `save_to_dir` is set). Default: "png".
+        follow_links: Whether to follow symlinks inside
+            class subdirectories (default: False).
+        subset: Subset of data (`"training"` or `"validation"`) if
+          `validation_split` is set in `ImageDataGenerator`.
+        interpolation: Interpolation method used to resample the image if the
+          target size is different from that of the loaded image. Supported
+          methods are `"nearest"`, `"bilinear"`, and `"bicubic"`. If PIL version
+          1.1.3 or newer is installed, `"lanczos"` is also supported. If PIL
+          version 3.4.0 or newer is installed, `"box"` and `"hamming"` are also
+          supported. By default, `"nearest"` is used.
+
+    Returns:
+        A `DirectoryIterator` yielding tuples of `(x, y)`
+            where `x` is a numpy array containing a batch
+            of images with shape `(batch_size, *target_size, channels)`
+            and `y` is a numpy array of corresponding labels.
+    """
+    return DirectoryIterator(
+        directory,
+        self,
+        target_size=target_size,
+        color_mode=color_mode,
+        classes=classes,
+        class_mode=class_mode,
+        data_format=self.data_format,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        seed=seed,
+        save_to_dir=save_to_dir,
+        save_prefix=save_prefix,
+        save_format=save_format,
+        follow_links=follow_links,
+        subset=subset,
+        interpolation=interpolation)
+
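A hedged sketch of driving `flow_from_directory`; the directory layout and class names are assumptions for illustration:

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Assumes a layout like data/train/{cats,dogs}/*.jpg (illustrative).
datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)
train_it = datagen.flow_from_directory(
    'data/train',
    target_size=(256, 256),
    class_mode='categorical',
    batch_size=32,
    subset='training')         # pairs with validation_split above
print(train_it.class_indices)  # e.g. {'cats': 0, 'dogs': 1}
```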
+  def flow_from_dataframe(self,
+                          dataframe,
+                          directory=None,
+                          x_col='filename',
+                          y_col='class',
+                          weight_col=None,
+                          target_size=(256, 256),
+                          color_mode='rgb',
+                          classes=None,
+                          class_mode='categorical',
+                          batch_size=32,
+                          shuffle=True,
+                          seed=None,
+                          save_to_dir=None,
+                          save_prefix='',
+                          save_format='png',
+                          subset=None,
+                          interpolation='nearest',
+                          validate_filenames=True,
+                          **kwargs):
+    """Takes the dataframe and the path to a directory and generates batches.
+
+    The generated batches contain augmented/normalized data.
+
+    A simple tutorial can be found [here](
+    http://bit.ly/keras_flow_from_dataframe).
+
+    Arguments:
+        dataframe: Pandas dataframe containing the filepaths relative to
+          `directory` (or absolute paths if `directory` is None) of the images
+          in a string column. It should include other column/s
+          depending on the `class_mode`:
+          - if `class_mode` is `"categorical"` (default value) it must include
+            the `y_col` column with the class/es of each image. Values in
+            column can be string/list/tuple if a single class or list/tuple if
+            multiple classes.
+          - if `class_mode` is `"binary"` or `"sparse"` it must include the
+            given `y_col` column with class values as strings.
+          - if `class_mode` is `"raw"` or `"multi_output"` it should contain
+            the columns specified in `y_col`.
+          - if `class_mode` is `"input"` or `None` no extra column is needed.
+        directory: string, path to the directory to read images from. If `None`,
+          data in `x_col` column should be absolute paths.
+        x_col: string, column in `dataframe` that contains the filenames (or
+          absolute paths if `directory` is `None`).
+        y_col: string or list, column/s in `dataframe` that has the target data.
+        weight_col: string, column in `dataframe` that contains the sample
+            weights. Default: `None`.
+        target_size: tuple of integers `(height, width)`, default: `(256, 256)`.
+          The dimensions to which all images found will be resized.
+        color_mode: one of "grayscale", "rgb", "rgba". Default: "rgb". Whether
+          the images will be converted to have 1, 3, or 4 channels.
+        classes: optional list of classes (e.g. `['dogs', 'cats']`). Default is
+          None. If not provided, the list of classes will be automatically
+          inferred from `y_col` (and the order of the classes, which will map
+          to the label indices, will be alphanumeric). The dictionary
+          containing the mapping from class names to class indices can be
+          obtained via the attribute `class_indices`.
+        class_mode: one of "binary", "categorical", "input", "multi_output",
+            "raw", "sparse" or None. Default: "categorical".
+            Mode for yielding the targets:
+            - `"binary"`: 1D numpy array of binary labels,
+            - `"categorical"`: 2D numpy array of one-hot encoded labels.
+              Supports multi-label output.
+            - `"input"`: images identical to input images (mainly used to work
+              with autoencoders),
+            - `"multi_output"`: list with the values of the different columns,
+            - `"raw"`: numpy array of values in `y_col` column(s),
+            - `"sparse"`: 1D numpy array of integer labels,
+            - `None`: no targets are returned (the generator will only yield
+              batches of image data, which is useful in
+              `model.predict_generator()`).
+        batch_size: size of the batches of data (default: 32).
+        shuffle: whether to shuffle the data (default: True).
+        seed: optional random seed for shuffling and transformations.
+        save_to_dir: None or str (default: None). This allows you to optionally
+          specify a directory to which to save the augmented pictures being
+          generated (useful for visualizing what you are doing).
+        save_prefix: str. Prefix to use for filenames of saved pictures (only
+          relevant if `save_to_dir` is set).
+        save_format: one of "png", "jpeg"
+            (only relevant if `save_to_dir` is set). Default: "png".
+        subset: Subset of data (`"training"` or `"validation"`) if
+          `validation_split` is set in `ImageDataGenerator`.
+        interpolation: Interpolation method used to resample the image if the
+          target size is different from that of the loaded image. Supported
+          methods are `"nearest"`, `"bilinear"`, and `"bicubic"`. If PIL version
+          1.1.3 or newer is installed, `"lanczos"` is also supported. If PIL
+          version 3.4.0 or newer is installed, `"box"` and `"hamming"` are also
+          supported. By default, `"nearest"` is used.
+        validate_filenames: Boolean, whether to validate image filenames in
+          `x_col`. If `True`, invalid images will be ignored. Disabling this
+          option can speed up the execution of this function.
+          Defaults to `True`.
+        **kwargs: legacy arguments for raising deprecation warnings.
+
+    Returns:
+        A `DataFrameIterator` yielding tuples of `(x, y)`
+        where `x` is a numpy array containing a batch
+        of images with shape `(batch_size, *target_size, channels)`
+        and `y` is a numpy array of corresponding labels.
+    """
+    if 'has_ext' in kwargs:
+      tf_logging.warn(
+          'has_ext is deprecated, filenames in the dataframe have '
+          'to match the exact filenames on disk.', DeprecationWarning)
+    if 'sort' in kwargs:
+      tf_logging.warn(
+          'sort is deprecated, batches will be created in the '
+          'same order as the filenames provided if shuffle '
+          'is set to False.', DeprecationWarning)
+    if class_mode == 'other':
+      tf_logging.warn(
+          '`class_mode` "other" is deprecated, please use '
+          '`class_mode` "raw".', DeprecationWarning)
+      class_mode = 'raw'
+    if 'drop_duplicates' in kwargs:
+      tf_logging.warn(
+          'drop_duplicates is deprecated, you can drop duplicates '
+          'by using the pandas.DataFrame.drop_duplicates method.',
+          DeprecationWarning)
+
+    return DataFrameIterator(
+        dataframe,
+        directory,
+        self,
+        x_col=x_col,
+        y_col=y_col,
+        weight_col=weight_col,
+        target_size=target_size,
+        color_mode=color_mode,
+        classes=classes,
+        class_mode=class_mode,
+        data_format=self.data_format,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        seed=seed,
+        save_to_dir=save_to_dir,
+        save_prefix=save_prefix,
+        save_format=save_format,
+        subset=subset,
+        interpolation=interpolation,
+        validate_filenames=validate_filenames)
+
+
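And a matching sketch for `flow_from_dataframe`; the dataframe contents and directory are illustrative assumptions:

```python
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# One filename column and one class column (illustrative).
df = pd.DataFrame({
    'filename': ['cat001.jpg', 'dog001.jpg'],
    'class': ['cat', 'dog'],
})
datagen = ImageDataGenerator(rescale=1. / 255)
it = datagen.flow_from_dataframe(
    df,
    directory='data/images',  # filenames are resolved relative to this
    x_col='filename',
    y_col='class',
    class_mode='categorical',
    target_size=(256, 256),
    validate_filenames=True)  # rows whose files are missing get dropped
```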
 keras_export('keras.preprocessing.image.random_rotation')(random_rotation)
 keras_export('keras.preprocessing.image.random_shift')(random_shift)
 keras_export('keras.preprocessing.image.random_shear')(random_shear)
diff --git a/tensorflow/python/keras/preprocessing/image_test.py b/tensorflow/python/keras/preprocessing/image_test.py
index d7da420..a577381 100644
--- a/tensorflow/python/keras/preprocessing/image_test.py
+++ b/tensorflow/python/keras/preprocessing/image_test.py
@@ -25,6 +25,9 @@
 import numpy as np
 
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import layers
+from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.preprocessing import image as preprocessing_image
 from tensorflow.python.platform import test
 
@@ -52,7 +55,7 @@
   return [rgb_images, gray_images]
 
 
-class TestImage(test.TestCase):
+class TestImage(keras_parameterized.TestCase):
 
   @test_util.run_v2_only
   def test_smart_resize(self):
@@ -319,14 +322,21 @@
     self.assertEqual(
         len(set(train_iterator.filenames) & set(filenames)), num_training)
 
+    model = sequential.Sequential([layers.Flatten(), layers.Dense(2)])
+    model.compile(optimizer='sgd', loss='mse')
+    model.fit(train_iterator, epochs=1)
+
     shutil.rmtree(tmp_folder)
 
+  @keras_parameterized.run_all_keras_modes
   def test_directory_iterator_with_validation_split_25_percent(self):
     self.directory_iterator_with_validation_split_test_helper(0.25)
 
+  @keras_parameterized.run_all_keras_modes
   def test_directory_iterator_with_validation_split_40_percent(self):
     self.directory_iterator_with_validation_split_test_helper(0.40)
 
+  @keras_parameterized.run_all_keras_modes
   def test_directory_iterator_with_validation_split_50_percent(self):
     self.directory_iterator_with_validation_split_test_helper(0.50)
 
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 3387923..5d58795 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -1021,7 +1021,7 @@
     # self._testWhileLoopWritePackGradients(
     #     dynamic_size=False, dtype=tf.int64)
 
-  @test_util.run_v1_only("b/117943489")
+  @test_util.run_deprecated_v1
   def testSkipEagerWhileLoopDynamicWritePackGradients(self):
     self._testWhileLoopWritePackGradients(
         dynamic_size=True, dtype=dtypes.float32)
@@ -1251,7 +1251,6 @@
       with self.assertRaises(ValueError):
         w1.write(4, c2)
 
-  @test_util.run_v1_only("b/117943489")
   def testUnpackShape(self):
     self._testUnpackShape()
 
@@ -1340,11 +1339,11 @@
       grad = gradients_impl.gradients(ys=[r], xs=[x])
       self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0])
 
-  @test_util.run_v1_only("b/117943489")
+  @test_util.run_deprecated_v1
   def testSkipEagerTensorArrayUnpackDynamic(self):
     self._testTensorArrayUnpackDynamic()
 
-  @test_util.run_v1_only("b/117943489")
+  @test_util.run_deprecated_v1
   def testSkipEagerTensorArraySplitDynamic(self):
     with self.session(use_gpu=True) as sess:
       ta = tensor_array_ops.TensorArray(
@@ -1422,7 +1421,7 @@
           v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
         ta.stack().eval()
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_deprecated_v1
   def testSkipEagerTensorArrayEvalEmpty(self):
     self._testTensorArrayEvalEmpty()
 
@@ -1445,11 +1444,11 @@
       # first dimension of zero
       self.assertAllEqual([0, 5], self.evaluate(concatenated).shape)
 
-  @test_util.run_v1_only("b/117943489")
+  @test_util.run_deprecated_v1
   def testSkipEagerTensorArrayEvalEmptyWithDefault(self):
     self._testTensorArrayEvalEmptyWithDefault()
 
-  @test_util.run_v1_only("b/117943489")
+  @test_util.run_deprecated_v1
   def testSkipEagerTensorArrayScatterReadAndGradients(self):
     with self.session(use_gpu=True) as session:
       ta = tensor_array_ops.TensorArray(
@@ -1476,7 +1475,7 @@
       self.assertAllEqual([10.0, -10.0], read_vals[1])
       self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
 
-  @test_util.run_v1_only("b/117943489")
+  @test_util.run_deprecated_v1
   def testSkipEagerTensorArrayScatterPartialReadAndGradients(self):
     with self.session(use_gpu=True) as session:
       ta = tensor_array_ops.TensorArray(
diff --git a/tensorflow/python/lib/core/pybind11_status.h b/tensorflow/python/lib/core/pybind11_status.h
index feb9747..3f9991c 100644
--- a/tensorflow/python/lib/core/pybind11_status.h
+++ b/tensorflow/python/lib/core/pybind11_status.h
@@ -69,6 +69,20 @@
   }
 }
 
+inline void MaybeRaiseRegisteredFromStatusWithGIL(
+    const tensorflow::Status& status) {
+  if (!status.ok()) {
+    // Acquire GIL for throwing exception.
+    pybind11::gil_scoped_acquire acquire;
+
+    PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
+                    pybind11::make_tuple(pybind11::none(), pybind11::none(),
+                                         status.error_message())
+                        .ptr());
+    throw pybind11::error_already_set();
+  }
+}
+
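From Python, the observable contract of this helper is unchanged: a failing call still raises the registered exception class even though the GIL was dropped around the C++ work. A hedged check (the path is illustrative):

```python
import tensorflow as tf

# A missing file should still surface as the mapped NotFoundError,
# not a generic error, after the GIL-release changes below.
try:
    tf.io.gfile.GFile('/no/such/file').read()
except tf.errors.NotFoundError as e:
    print('caught NotFoundError:', e.message)
```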
 inline void MaybeRaiseFromTFStatus(TF_Status* status) {
   TF_Code code = TF_GetCode(status);
   if (code != TF_OK) {
diff --git a/tensorflow/python/lib/io/file_io_wrapper.cc b/tensorflow/python/lib/io/file_io_wrapper.cc
index de806a9..0a2410b 100644
--- a/tensorflow/python/lib/io/file_io_wrapper.cc
+++ b/tensorflow/python/lib/io/file_io_wrapper.cc
@@ -42,50 +42,65 @@
       py::gil_scoped_release release;
       status = tensorflow::Env::Default()->FileExists(filename);
     }
-    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
   });
   m.def("DeleteFile", [](const std::string& filename) {
-    tensorflow::MaybeRaiseRegisteredFromStatus(
-        tensorflow::Env::Default()->DeleteFile(filename));
+    py::gil_scoped_release release;
+    tensorflow::Status status =
+        tensorflow::Env::Default()->DeleteFile(filename);
+    tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
   });
   m.def("ReadFileToString", [](const std::string& filename) {
     std::string data;
+    py::gil_scoped_release release;
     const auto status =
         ReadFileToString(tensorflow::Env::Default(), filename, &data);
+    pybind11::gil_scoped_acquire acquire;
     tensorflow::MaybeRaiseRegisteredFromStatus(status);
     return py::bytes(data);
   });
   m.def("WriteStringToFile",
         [](const std::string& filename, tensorflow::StringPiece data) {
-          return WriteStringToFile(tensorflow::Env::Default(), filename, data);
+          py::gil_scoped_release release;
+          const auto status =
+              WriteStringToFile(tensorflow::Env::Default(), filename, data);
+          tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
         });
   m.def("GetChildren", [](const std::string& dirname) {
     std::vector<std::string> results;
+    py::gil_scoped_release release;
     const auto status =
         tensorflow::Env::Default()->GetChildren(dirname, &results);
+    pybind11::gil_scoped_acquire acquire;
     tensorflow::MaybeRaiseRegisteredFromStatus(status);
     return results;
   });
   m.def("GetMatchingFiles", [](const std::string& pattern) {
     std::vector<std::string> results;
+    py::gil_scoped_release release;
     const auto status =
         tensorflow::Env::Default()->GetMatchingPaths(pattern, &results);
+    pybind11::gil_scoped_acquire acquire;
     tensorflow::MaybeRaiseRegisteredFromStatus(status);
     return results;
   });
   m.def("CreateDir", [](const std::string& dirname) {
+    py::gil_scoped_release release;
     const auto status = tensorflow::Env::Default()->CreateDir(dirname);
     if (tensorflow::errors::IsAlreadyExists(status)) {
       return;
     }
-    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
   });
   m.def("RecursivelyCreateDir", [](const std::string& dirname) {
-    tensorflow::MaybeRaiseRegisteredFromStatus(
-        tensorflow::Env::Default()->RecursivelyCreateDir(dirname));
+    py::gil_scoped_release release;
+    const auto status =
+        tensorflow::Env::Default()->RecursivelyCreateDir(dirname);
+    tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
   });
   m.def("CopyFile",
         [](const std::string& src, const std::string& target, bool overwrite) {
+          py::gil_scoped_release release;
           auto* env = tensorflow::Env::Default();
           tensorflow::Status status;
           if (!overwrite && env->FileExists(target).ok()) {
@@ -93,10 +108,11 @@
           } else {
             status = env->CopyFile(src, target);
           }
-          tensorflow::MaybeRaiseRegisteredFromStatus(status);
+          tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
         });
   m.def("RenameFile",
         [](const std::string& src, const std::string& target, bool overwrite) {
+          py::gil_scoped_release release;
           auto* env = tensorflow::Env::Default();
           tensorflow::Status status;
           if (!overwrite && env->FileExists(target).ok()) {
@@ -104,9 +120,10 @@
           } else {
             status = env->RenameFile(src, target);
           }
-          tensorflow::MaybeRaiseRegisteredFromStatus(status);
+          tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
         });
   m.def("DeleteRecursively", [](const std::string& dirname) {
+    py::gil_scoped_release release;
     tensorflow::int64 undeleted_files;
     tensorflow::int64 undeleted_dirs;
     auto status = tensorflow::Env::Default()->DeleteRecursively(
@@ -115,23 +132,25 @@
       status =
           tensorflow::errors::PermissionDenied("could not fully delete dir");
     }
-    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
   });
   m.def("IsDirectory", [](const std::string& dirname) {
+    py::gil_scoped_release release;
     const auto status = tensorflow::Env::Default()->IsDirectory(dirname);
     // FAILED_PRECONDITION response means path exists but isn't a dir.
     if (tensorflow::errors::IsFailedPrecondition(status)) {
       return false;
     }
 
-    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
     return true;
   });
   m.def("HasAtomicMove", [](const std::string& path) {
+    py::gil_scoped_release release;
     bool has_atomic_move;
     const auto status =
         tensorflow::Env::Default()->HasAtomicMove(path, &has_atomic_move);
-    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
     return has_atomic_move;
   });
 
@@ -141,9 +160,11 @@
       .def_readonly("is_directory", &tensorflow::FileStatistics::is_directory);
 
   m.def("Stat", [](const std::string& filename) {
+    py::gil_scoped_release release;
     std::unique_ptr<tensorflow::FileStatistics> self(
         new tensorflow::FileStatistics);
     const auto status = tensorflow::Env::Default()->Stat(filename, self.get());
+    py::gil_scoped_acquire acquire;
     tensorflow::MaybeRaiseRegisteredFromStatus(status);
     return self.release();
   });
@@ -151,66 +172,83 @@
   using tensorflow::WritableFile;
   py::class_<WritableFile>(m, "WritableFile")
       .def(py::init([](const std::string& filename, const std::string& mode) {
+        py::gil_scoped_release release;
         auto* env = tensorflow::Env::Default();
         std::unique_ptr<WritableFile> self;
         const auto status = mode.find("a") == std::string::npos
                                 ? env->NewWritableFile(filename, &self)
                                 : env->NewAppendableFile(filename, &self);
+        py::gil_scoped_acquire acquire;
         tensorflow::MaybeRaiseRegisteredFromStatus(status);
         return self.release();
       }))
       .def("append",
            [](WritableFile* self, tensorflow::StringPiece data) {
-             tensorflow::MaybeRaiseRegisteredFromStatus(self->Append(data));
+             const auto status = self->Append(data);
+             tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
            })
       // TODO(slebedev): Make WritableFile::Tell const and change self
       // to be a reference.
       .def("tell",
            [](WritableFile* self) {
              tensorflow::int64 pos = -1;
+             py::gil_scoped_release release;
              const auto status = self->Tell(&pos);
-             tensorflow::MaybeRaiseRegisteredFromStatus(status);
+             tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
              return pos;
            })
       .def("flush",
            [](WritableFile* self) {
-             tensorflow::MaybeRaiseRegisteredFromStatus(self->Flush());
+             py::gil_scoped_release release;
+             tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(self->Flush());
            })
       .def("close", [](WritableFile* self) {
-        tensorflow::MaybeRaiseRegisteredFromStatus(self->Close());
+        py::gil_scoped_release release;
+        tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(self->Close());
       });
 
   using tensorflow::io::BufferedInputStream;
   py::class_<BufferedInputStream>(m, "BufferedInputStream")
       .def(py::init([](const std::string& filename, size_t buffer_size) {
+        py::gil_scoped_release release;
         std::unique_ptr<tensorflow::RandomAccessFile> file;
         const auto status =
             tensorflow::Env::Default()->NewRandomAccessFile(filename, &file);
-        tensorflow::MaybeRaiseRegisteredFromStatus(status);
+        tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
         std::unique_ptr<tensorflow::io::RandomAccessInputStream> input_stream(
             new tensorflow::io::RandomAccessInputStream(file.release(),
                                                         /*owns_file=*/true));
+        py::gil_scoped_acquire acquire;
         return new BufferedInputStream(input_stream.release(), buffer_size,
                                        /*owns_input_stream=*/true);
       }))
       .def("read",
            [](BufferedInputStream* self, tensorflow::int64 bytes_to_read) {
+             py::gil_scoped_release release;
              tensorflow::tstring result;
              const auto status = self->ReadNBytes(bytes_to_read, &result);
              if (!status.ok() && !tensorflow::errors::IsOutOfRange(status)) {
                result.clear();
                tensorflow::MaybeRaiseRegisteredFromStatus(status);
              }
+             py::gil_scoped_acquire acquire;
              return py::bytes(result);
            })
       .def("readline",
            [](BufferedInputStream* self) {
-             return py::bytes(self->ReadLineAsString());
+             py::gil_scoped_release release;
+             auto output = self->ReadLineAsString();
+             py::gil_scoped_acquire acquire;
+             return py::bytes(output);
            })
       .def("seek",
            [](BufferedInputStream* self, tensorflow::int64 pos) {
-             tensorflow::MaybeRaiseRegisteredFromStatus(self->Seek(pos));
+             py::gil_scoped_release release;
+             tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(self->Seek(pos));
            })
-      .def("tell", [](BufferedInputStream* self) { return self->Tell(); });
+      .def("tell", [](BufferedInputStream* self) {
+        py::gil_scoped_release release;
+        return self->Tell();
+      });
 }
 }  // namespace
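The payoff of the release/acquire pairs above is that blocking file I/O no longer serializes Python threads. A minimal sketch, assuming `tf.io.gfile` routes into this wrapper (paths illustrative):

```python
import threading
import tensorflow as tf

def probe(paths):
    # Each gfile call now drops the GIL while blocked on the
    # filesystem, so these threads can overlap their I/O.
    for p in paths:
        tf.io.gfile.exists(p)

threads = [threading.Thread(target=probe, args=(['/tmp/a', '/tmp/b'],))
           for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```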
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 33aac84..1cb6fdb 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -39,6 +39,7 @@
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_array_ops import *
 from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse  # pylint: disable=unused-import
+from tensorflow.python.types import core
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import dispatch
 from tensorflow.python.util import nest
@@ -1381,13 +1382,13 @@
   if context.executing_eagerly():
     # NOTE: Fast path when all the items are tensors, this doesn't do any type
     # checking.
-    if all(ops.is_dense_tensor_like(elem) for elem in list_or_tuple):
+    if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
       return gen_array_ops.pack(list_or_tuple, name=name)
   must_pack = False
   converted_elems = []
   with ops.name_scope(name) as scope:
     for i, elem in enumerate(list_or_tuple):
-      if ops.is_dense_tensor_like(elem):
+      if isinstance(elem, core.Tensor):
         if dtype is not None and elem.dtype.base_dtype != dtype:
           raise TypeError("Cannot convert a list containing a tensor of dtype "
                           "%s to %s (Tensor is: %r)" %
@@ -1396,7 +1397,7 @@
         must_pack = True
       elif isinstance(elem, (list, tuple)):
         converted_elem = _autopacking_helper(elem, dtype, str(i))
-        if ops.is_dense_tensor_like(converted_elem):
+        if isinstance(converted_elem, core.Tensor):
           must_pack = True
         converted_elems.append(converted_elem)
       else:
@@ -1404,7 +1405,7 @@
     if must_pack:
       elems_as_tensors = []
       for i, elem in enumerate(converted_elems):
-        if ops.is_dense_tensor_like(elem):
+        if isinstance(elem, core.Tensor):
           elems_as_tensors.append(elem)
         else:
           # NOTE(mrry): This is inefficient, but it enables us to
@@ -1429,7 +1430,7 @@
     such object exists.
   """
   for elem in list_or_tuple:
-    if ops.is_dense_tensor_like(elem):
+    if isinstance(elem, core.Tensor):
       return elem.dtype.base_dtype
     elif isinstance(elem, (list, tuple)):
       maybe_dtype = _get_dtype_from_nested_lists(elem)
@@ -1441,7 +1442,7 @@
 def _cast_nested_seqs_to_dtype(dtype):
 
   def _maybe_cast(elem):
-    if ops.is_dense_tensor_like(elem):
+    if isinstance(elem, core.Tensor):
       if dtype != elem.dtype.base_dtype:
         elem = gen_math_ops.cast(elem, dtype)
     return elem
@@ -1455,7 +1456,7 @@
 
 def _should_not_autopack(v):
   # The condition we really want is
-  #    ops.is_dense_tensor_like(...)
+  #    any(isinstance(elem, core.Tensor) for elem in v)
   # but it is >5x slower due to abc.ABCMeta.__instancecheck__.
   # pylint: disable=unidiomatic-typecheck
   # TODO(slebedev): add nest.all?
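The new `core.Tensor` base class turns the dense-tensor check into a plain `isinstance`. A minimal sketch of the idiom, using the internal import added above (behavior assumed as of this patch):

```python
import tensorflow as tf
from tensorflow.python.types import core

t = tf.constant([1, 2, 3])
v = tf.Variable([4.0])

# ops.Tensor and the variable classes now derive from core.Tensor,
# replacing the old ops.is_dense_tensor_like() registry lookups.
print(isinstance(t, core.Tensor))  # True
print(isinstance(v, core.Tensor))  # True (ResourceVariable subclasses it)
```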
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index de5be20..248c57c 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -45,6 +45,7 @@
 # pylint: enable=wildcard-import
 from tensorflow.python.platform import device_context
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import dispatch
 from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.deprecation import deprecated_args
 from tensorflow.python.util.deprecation import deprecated_argument_lookup
@@ -4513,6 +4514,7 @@
 
 
 @tf_export(v1=["nn.dropout"])
+@dispatch.add_dispatch_support
 @deprecation.deprecated_args(None, "Please use `rate` instead of `keep_prob`. "
                              "Rate should be set to `rate = 1 - keep_prob`.",
                              "keep_prob")
@@ -4567,6 +4569,7 @@
 
 
 @tf_export("nn.dropout", v1=[])
+@dispatch.add_dispatch_support
 def dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
   """Computes dropout: randomly sets elements to zero to prevent overfitting.
 
diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py
index dd5bd78..f13bed0 100644
--- a/tensorflow/python/ops/ragged/ragged_dispatch.py
+++ b/tensorflow/python/ops/ragged/ragged_dispatch.py
@@ -30,6 +30,7 @@
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_bitwise_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variables
@@ -453,6 +454,26 @@
                                                      num_partitions, name)
   return [result[i] for i in range(num_partitions)]
 
+
+def _ragged_nn_dropout_v1(x, keep_prob=None, noise_shape=None, seed=None,
+                          name=None, rate=None):
+  if noise_shape is not None:
+    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
+  with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
+    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
+    return x.with_flat_values(nn_ops.dropout(x.flat_values, keep_prob=keep_prob,
+                                             seed=seed, rate=rate))
+
+
+def _ragged_nn_dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
+  if noise_shape is not None:
+    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
+  with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
+    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
+    return x.with_flat_values(nn_ops.dropout_v2(x.flat_values, rate=rate,
+                                                seed=seed))
+
+
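With the dispatch entries registered just below, `tf.nn.dropout` now accepts ragged input by operating on its flat values. A hedged example of the enabled path:

```python
import tensorflow as tf

rt = tf.ragged.constant([[1.0, 2.0, 3.0], [4.0]])

# Dispatch routes this through _ragged_nn_dropout_v2: dropout runs on
# rt.flat_values and the row partitioning is preserved.
dropped = tf.nn.dropout(rt, rate=0.5, seed=1)
print(dropped.shape)  # (2, None)
```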
 # (original_op, ragged_op, ragged_args)
 _RAGGED_DISPATCH_OPS = [
     (array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
@@ -497,6 +518,8 @@
     (math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
     (math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
     (math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
+    (nn_ops.dropout, _ragged_nn_dropout_v1, ['x']),
+    (nn_ops.dropout_v2, _ragged_nn_dropout_v2, ['x']),
 ]
 
 
diff --git a/tensorflow/python/ops/ragged/ragged_dispatch_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py
index 0ce9a6f..60d9f6c 100644
--- a/tensorflow/python/ops/ragged/ragged_dispatch_test.py
+++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py
@@ -32,6 +32,7 @@
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_bitwise_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_dispatch
@@ -232,6 +233,10 @@
           {'op': array_ops.check_numerics,
            'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
            'message': 'check-numerics'},
+          {'op': nn_ops.dropout,
+           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
+           'rate': 0.5,
+           'seed': 1},
       ]
       )  # pyformat: disable
   def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
@@ -820,7 +825,8 @@
         'strings.substr', 'strings.to_hash_bucket_fast',
         'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
         'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
-        'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse'
+        'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse',
+        'nn.dropout',
     ]
 
     # Ops that should be listed as supported in v1 only.
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index f99f886..d8a7765 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -49,6 +49,7 @@
 from tensorflow.python.ops.gen_resource_variable_ops import *
 # pylint: enable=wildcard-import
 from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.types import core
 from tensorflow.python.util import compat
 from tensorflow.python.util.deprecation import deprecated
 
@@ -330,7 +331,7 @@
     tape.variable_accessed(variable)
 
 
-class BaseResourceVariable(variables.VariableV1):
+class BaseResourceVariable(variables.VariableV1, core.Tensor):
   """A python variable from an existing handle."""
 
   # TODO(wangpeng): Deprecate `constraint` when callers no long pass it in.
@@ -1830,7 +1831,6 @@
 # allowing instances of the class to be used as tensors.
 ops.register_tensor_conversion_function(BaseResourceVariable,
                                         _dense_var_to_tensor)
-ops.register_dense_tensor_like_type(BaseResourceVariable)
 
 
 class _UnreadVariable(BaseResourceVariable):
@@ -1955,9 +1955,6 @@
     return self._parent_op
 
 
-ops.register_dense_tensor_like_type(_UnreadVariable)
-
-
 @ops.RegisterGradient("ReadVariableOp")
 def _ReadGrad(_, grad):
   """Gradient for read op."""
diff --git a/tensorflow/python/ops/signal/mel_ops.py b/tensorflow/python/ops/signal/mel_ops.py
index aa07691..b95876b 100644
--- a/tensorflow/python/ops/signal/mel_ops.py
+++ b/tensorflow/python/ops/signal/mel_ops.py
@@ -128,8 +128,6 @@
       # S has shape [..., num_spectrogram_bins].
       # M has shape [..., num_mel_bins].
       M = tf.tensordot(S, A, 1)
-      # tf.tensordot does not support shape inference for this case yet.
-      M.set_shape(S.shape[:-1].concatenate(A.shape[-1:]))
 
   Args:
     num_mel_bins: Python int. How many bands in the resulting mel spectrum.
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index d65cd23..81c3f9a 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -42,6 +42,7 @@
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.types import core
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import tf_contextlib
@@ -1000,7 +1001,7 @@
     return initializer, initializing_from_value
 
 
-class _LazyEvalTensor(object):
+class _LazyEvalTensor(core.Tensor):
   """A Tensor-like object that only evaluates its thunk when used."""
 
   def __init__(self, thunk):
@@ -1069,8 +1070,6 @@
     lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0])  # pylint: disable=protected-access
     )
 
-ops.register_dense_tensor_like_type(_LazyEvalTensor)
-
 
 # To stop regularization, use this regularizer
 @tf_export(v1=["no_regularizer"])
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 1080778..d3df065 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -47,6 +47,7 @@
 from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.deprecation import deprecated_args
 from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.types import core
 
 
 def default_variable_creator(_, **kwds):
@@ -264,6 +265,7 @@
 
 
 @tf_export("Variable", v1=[])
+# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
 class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
   """See the [variable guide](https://tensorflow.org/guide/variable).
 
@@ -1551,7 +1553,7 @@
 
 
 # TODO(apassos): do not repeat all comments here
-class RefVariable(VariableV1):
+class RefVariable(VariableV1, core.Tensor):
   """Ref-based implementation of variables."""
 
   def __init__(
@@ -3032,7 +3034,6 @@
 # allowing instances of the class to be used as tensors.
 ops.register_tensor_conversion_function(RefVariable,
                                         RefVariable._TensorConversionFunction)  # pylint: disable=protected-access
-ops.register_dense_tensor_like_type(RefVariable)
 
 
 @tf_export(v1=["global_variables"])
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index e5ca608..b6565f5 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -226,6 +226,7 @@
     deps = [
         "//tensorflow/python:util",
         "//tensorflow/python/profiler/internal:_pywrap_traceme",
+        "//tensorflow/python/types",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index aeca90b..b36a1f2 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -332,6 +332,11 @@
 
     functions[fdef.signature.name] = func
     renamed_functions[func.name] = func
+    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
+      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
+      # is fixed. Currently it's leaking memory to maintain bug compatibility
+      # with previous behavior.
+      func.add_to_graph(ops.get_default_graph())
 
   return functions
 
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index 42e971d..0f635b6 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -178,7 +178,7 @@
     spec = struct_coder.decode_proto(spec_proto)
     components = [_get_tensor(component.name) for component in
                   tensor_info.composite_tensor.components]
-    return spec._from_components(components)  # pylint: disable=protected-access
+    return nest.pack_sequence_as(spec, components, expand_composites=True)
   else:
     raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
 
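`nest.pack_sequence_as(..., expand_composites=True)` rebuilds a composite value from a spec plus flat components, which is the public replacement for the private `_from_components` call above. A sketch; `tf.type_spec_from_value` is assumed available in this build:

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])
spec = tf.type_spec_from_value(rt)
components = tf.nest.flatten(rt, expand_composites=True)

# Equivalent to the removed spec._from_components(components):
rebuilt = tf.nest.pack_sequence_as(spec, components, expand_composites=True)
print(rebuilt)  # <tf.RaggedTensor [[1, 2], [3]]>
```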
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index fa07a92..d1848f3 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -828,7 +828,7 @@
   ...     end_learning_rate=0.0)
   >>> wordpiece_table_config = TableConfig(
   ...   vocabulary_size=119547,
-  ...   dimension=768,
+  ...   dimension=256,
   ...   learning_rate_fn=learning_rate_fn)
   >>> wordpiece_feature_config = FeatureConfig(
   ...   table_id='bert/embeddings/word_embeddings',
@@ -846,11 +846,11 @@
   ...  batch_size=128,
   ...  mode=TRAINING,
   ...  optimization_parameters=optimization_parameters,
-  ...  device_config=DeviceConfig(
-  ...    num_cores=64, num_hosts=4, job_name='tpu_worker'))
+  ...  master='')
   >>> with tf.Graph().as_default():
   ...   init_tpu_op = tf.compat.v1.tpu.initialize_system(
-  ...     embedding_config=tpu_embedding.config_proto, job='tpu_worker')
+  ...     embedding_config=tpu_embedding.config_proto)
+  ...   tf.compat.v1.Session().run(init_tpu_op)
   """
 
   # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD
index f35ca7f..e93bf5c 100644
--- a/tensorflow/python/types/BUILD
+++ b/tensorflow/python/types/BUILD
@@ -27,6 +27,9 @@
         "internal.py",
     ],
     srcs_version = "PY2AND3",
-    visibility = ["//tensorflow:__subpackages__"],
+    visibility = [
+        "//tensorflow:__subpackages__",
+        "//tensorflow:types_whitelist",
+    ],
     deps = [],
 )
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 61d6656..f56330b 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -193,10 +193,10 @@
         "//conditions:default": otherwise,
     })
 
-def if_ios(a):
+def if_ios(a, otherwise = []):
     return select({
         clean_dep("//tensorflow:ios"): a,
-        "//conditions:default": [],
+        "//conditions:default": otherwise,
     })
 
 def if_ios_x86_64(a):
diff --git a/tensorflow/tools/android/inference_interface/BUILD b/tensorflow/tools/android/inference_interface/BUILD
index cbd161f..fb3ab00 100644
--- a/tensorflow/tools/android/inference_interface/BUILD
+++ b/tensorflow/tools/android/inference_interface/BUILD
@@ -34,7 +34,7 @@
     copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = [
-        "//tensorflow/core:android_tensorflow_lib_lite",
+        "//tensorflow/core:portable_tensorflow_lib_lite",
         "//tensorflow/java/src/main/native",
     ],
     alwayslink = 1,
@@ -83,7 +83,7 @@
     ],
     deps = [
         ":android_tensorflow_inference_jni",
-        "//tensorflow/core:android_tensorflow_lib",
+        "//tensorflow/core:portable_tensorflow_lib",
         LINKER_SCRIPT,
     ],
 )
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
index 4a30fae..9315973 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
@@ -2,6 +2,7 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
   is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
+  is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "OVERLOADABLE_OPERATORS"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
index 4a30fae..9315973 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
@@ -2,6 +2,7 @@
 tf_class {
   is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
   is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
+  is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "OVERLOADABLE_OPERATORS"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt
index 0b49aa9..e59c78c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt
@@ -68,4 +68,8 @@
     name: "save_img"
     argspec: "args=[\'path\', \'x\', \'data_format\', \'file_format\', \'scale\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'True\'], "
   }
+  member_method {
+    name: "smart_resize"
+    argspec: "args=[\'x\', \'size\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'bilinear\'], "
+  }
 }
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython
index 353d946..9c85091 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython
@@ -55,17 +55,17 @@
 COPY install/install_deb_packages.sh /install/
 RUN /install/install_deb_packages.sh
 
-# Install patchelf to facilitate the creation of manylinux2010 whls.
-COPY install/install_patchelf.sh /install/
-RUN /install/install_patchelf.sh
-
-# Install additional dependencies to build Python from source.
+# Install additional packages needed for this image:
+# - dependencies to build Python from source
+# - patchelf, as it is required by auditwheel
 RUN apt-get update && apt-get install -y \
-    libncurses5-dev \
+    libbz2-dev \
+    libffi-dev \
     libgdbm-dev \
+    libncurses5-dev \
     libnss3-dev \
     libreadline-dev \
-    libffi-dev \
+    patchelf \
       && \
     rm -rf /var/lib/apt/lists/*
 
@@ -86,9 +86,6 @@
 RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.6"
 RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.7"
 
-# Install auditwheel to create manylinux2010 compliant binaries
-RUN pip3 install auditwheel
-
 ENV CLANG_VERSION="r42cab985fd95ba4f3f290e7bb26b93805edb447d"
 COPY install/install_latest_clang.sh /install/
 RUN /install/install_latest_clang.sh
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh b/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh
index d9953db..81e5f2b 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh
@@ -26,6 +26,7 @@
 fi
 
 PACKAGES=(
+  "auditwheel"
   "wheel"
   "setuptools"
   "virtualenv"
diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh
index a6ef52b..bb40042 100644
--- a/tensorflow/tools/ci_build/release/common.sh
+++ b/tensorflow/tools/ci_build/release/common.sh
@@ -146,6 +146,7 @@
   ${PIP_CMD} install --user --upgrade attrs
   ${PIP_CMD} install --user --upgrade tf-estimator-nightly
   ${PIP_CMD} install --user --upgrade "future>=0.17.1"
+  ${PIP_CMD} install --user --upgrade wrapt
   # LINT.ThenChange(:ubuntu_16_pip_installations)
 }
 
@@ -178,6 +179,7 @@
   "${PIP_CMD}" install PyYAML==3.13 --user
   "${PIP_CMD}" install --user --upgrade tf-estimator-nightly
   "${PIP_CMD}" install --user --upgrade tb-nightly
+  "${PIP_CMD}" install --user --upgrade wrapt
   # LINT.ThenChange(:ubuntu_pip_installations)
 }
 
@@ -219,6 +221,7 @@
   ${SUDO_CMD} ${PIP_CMD} install --upgrade tb-nightly
   ${PIP_CMD} install --user --upgrade attrs
   ${PIP_CMD} install --user --upgrade tf-estimator-nightly
+  ${PIP_CMD} install --user --upgrade wrapt
   ${PIP_CMD} install --user --upgrade "future>=0.17.1"
 }
 
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index c092f21..c0442a5 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -2,6 +2,7 @@
 #   Doc generator
 
 load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
 
 package(
     default_visibility = ["//tensorflow:__subpackages__"],
@@ -22,6 +23,7 @@
 py_test(
     name = "tf_doctest",
     srcs = ["tf_doctest.py"],
+    args = ["--module_prefix_skip=tpu.,distribute.tpu_strategy"],
     python_version = "PY3",
     tags = [
         "no_oss_py2",
@@ -40,6 +42,28 @@
     ],
 )
 
+tpu_py_test(
+    name = "tf_doctest_tpu",
+    srcs = ["tf_doctest.py"],
+    args = ["--module=tpu.,distribute.tpu_strategy"],
+    disable_experimental = True,
+    disable_v3 = True,
+    main = "tf_doctest.py",
+    python_version = "PY3",
+    tags = [
+        "no_oss",
+        "noasan",
+        "nomsan",
+        "notsan",
+    ],
+    deps = [
+        ":tf_doctest_lib",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/preprocessing",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_test(
     name = "tf_doctest_test",
     srcs = ["tf_doctest_test.py"],
diff --git a/tensorflow/tools/docs/tf_doctest.py b/tensorflow/tools/docs/tf_doctest.py
index 1962465..fc81d33 100644
--- a/tensorflow/tools/docs/tf_doctest.py
+++ b/tensorflow/tools/docs/tf_doctest.py
@@ -42,7 +42,9 @@
 
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
+flags.DEFINE_list('module', [],
+                  'A list of specific modules to run doctest on.')
+flags.DEFINE_list('module_prefix_skip', [],
+                  'A list of module prefixes to skip while running doctest.')
 flags.DEFINE_boolean('list', None,
                      'List all the modules in the core package imported.')
 flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
@@ -50,6 +52,7 @@
 flags.mark_flags_as_mutual_exclusive(['module', 'file'])
 flags.mark_flags_as_mutual_exclusive(['list', 'file'])
 
+# Both --module and --module_prefix_skip are relative to PACKAGE.
 PACKAGE = 'tensorflow.python.'
 
 
@@ -68,23 +71,24 @@
   return tf_modules
 
 
-def filter_on_submodules(all_modules, submodule):
-  """Filters all the modules based on the module flag.
+def filter_on_submodules(all_modules, submodules):
+  """Filters all the modules based on the modules flag.
 
   The module flag has to be relative to the core package imported.
-  For example, if `submodule=keras.layers` then, this function will return
+  For example, if `module=keras.layers`, then this function will return
   all the modules in the submodule.
 
   Args:
     all_modules: All the modules in the core package.
-    submodule: Submodule to filter from all the modules.
+    submodules: Submodules to filter from all the modules.
 
   Returns:
     All the modules in the submodule.
   """
 
   filtered_modules = [
-      mod for mod in all_modules if PACKAGE + submodule in mod.__name__
+      mod for mod in all_modules
+      if any(PACKAGE + submodule in mod.__name__ for submodule in submodules)
   ]
   return filtered_modules
 
@@ -140,6 +144,9 @@
     tf_modules = get_module_and_inject_docstring(FLAGS.file)
 
   for module in tf_modules:
+    if any(module.__name__.startswith(PACKAGE + prefix)
+           for prefix in FLAGS.module_prefix_skip):
+      continue
     testcase = TfTestCase()
     tests.addTests(
         doctest.DocTestSuite(
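Taken together, the two list-valued flags compose as a filter followed by a prefix skip. A standalone sketch of the selection logic (module names illustrative):

```python
PACKAGE = 'tensorflow.python.'

def select_modules(all_names, modules, module_prefix_skip):
    # Keep names matching any --module entry...
    picked = [n for n in all_names
              if any(PACKAGE + m in n for m in modules)]
    # ...then drop names matching any --module_prefix_skip prefix.
    return [n for n in picked
            if not any(n.startswith(PACKAGE + p)
                       for p in module_prefix_skip)]

names = ['tensorflow.python.tpu.bfloat16',
         'tensorflow.python.keras.layers',
         'tensorflow.python.distribute.tpu_strategy']
print(select_modules(names, ['tpu.', 'keras'], ['tpu.']))
# ['tensorflow.python.keras.layers']
```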
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 5b479c5..fe548fd 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -162,8 +162,8 @@
         print("path_prefix was specified to tf_workspace but is no longer used " +
               "and will be removed in the future.")
 
-    TFRT_COMMIT = "0bad623e8d99ace05f7f60e9e7f8b53ec813d66a"
-    TFRT_SHA256 = "d002429866d2d824a80dcf6c1602a15398412bc01324200d371c55b13b9a4b27"
+    TFRT_COMMIT = "341ba0448c117af4e29ae3911141265ee8e57860"
+    TFRT_SHA256 = "27716458f8ca7d91fc2d0f681127dbdd478eea78d6da5153c51b4696ebd14d55"
     TFRT_URLS = [
         "http://mirror.tensorflow.org/github.com/tensorflow/runtime/archive/{commit}.zip".format(commit = TFRT_COMMIT),
         "https://github.com/tensorflow/runtime/archive/{commit}.zip".format(commit = TFRT_COMMIT),
@@ -679,8 +679,8 @@
     )
 
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "c5e0967e4cf0f1337bec772949e6cede4c01354b"
-    LLVM_SHA256 = "5d8dbddd78fbc1c08825b178aff0a0f04722d83280eb93be55a174391f1885ce"
+    LLVM_COMMIT = "123bee602a260150ff55c74287f583a67ee78f36"
+    LLVM_SHA256 = "313ec75e47ea3f128724a61b8b6b45b7d305ba2ae57a5084b4bf1f881b4ec8f2"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
index f3b2ae6..303339e 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
@@ -53,13 +53,6 @@
 PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
 NVCC_VERSION = '%{cuda_version}'
 
-
-# TODO(amitpatankar): Benchmark enabling all capabilities by default.
-# Environment variable for supported TF CUDA Compute Capabilities
-# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
-CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES'
-DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
-
 def Log(s):
   print('gpus/crosstool: {0}'.format(s))
 
@@ -78,7 +71,8 @@
   """
 
   parser = ArgumentParser()
-  parser.add_argument('-' + option, nargs='*', action='append')
+  parser.add_argument(option, nargs='*', action='append')
+  option = option.lstrip('-').replace('-', '_')
   args, _ = parser.parse_known_args(argv)
   if not args or not vars(args)[option]:
     return []
@@ -180,17 +174,17 @@
 
   host_compiler_options = GetHostCompilerOptions(argv)
   nvcc_compiler_options = GetNvccOptions(argv)
-  opt_option = GetOptionValue(argv, 'O')
-  m_options = GetOptionValue(argv, 'm')
+  opt_option = GetOptionValue(argv, '-O')
+  m_options = GetOptionValue(argv, '-m')
   m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
-  include_options = GetOptionValue(argv, 'I')
-  out_file = GetOptionValue(argv, 'o')
-  depfiles = GetOptionValue(argv, 'MF')
-  defines = GetOptionValue(argv, 'D')
+  include_options = GetOptionValue(argv, '-I')
+  out_file = GetOptionValue(argv, '-o')
+  depfiles = GetOptionValue(argv, '-MF')
+  defines = GetOptionValue(argv, '-D')
   defines = ''.join([' -D' + define for define in defines])
-  undefines = GetOptionValue(argv, 'U')
+  undefines = GetOptionValue(argv, '-U')
   undefines = ''.join([' -U' + define for define in undefines])
-  std_options = GetOptionValue(argv, 'std')
+  std_options = GetOptionValue(argv, '-std')
   # Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang.
   nvcc_allowed_std_options = ["c++03", "c++11", "c++14"]
   std_options = ''.join([' -std=' + define
@@ -198,7 +192,7 @@
 
   # The list of source files get passed after the -c option. I don't know of
   # any other reliable way to just get the list of source files to be compiled.
-  src_files = GetOptionValue(argv, 'c')
+  src_files = GetOptionValue(argv, '-c')
 
   # Pass -w through from host to nvcc, but don't do anything fancier with
   # warnings-related flags, since they're not necessarily the same across
@@ -224,13 +218,12 @@
   srcs = ' '.join(src_files)
   out = ' -o ' + out_file[0]
 
-  supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ]
   nvccopts = '-D_FORCE_INLINES '
-  for capability in supported_cuda_compute_capabilities:
-    capability = capability.replace('.', '')
+  for capability in GetOptionValue(argv, "--cuda-gpu-arch"):
+    capability = capability[len('sm_'):]
     nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
         capability, capability, capability)
-  nvccopts += ' ' + nvcc_compiler_options
+  nvccopts += nvcc_compiler_options
   nvccopts += undefines
   nvccopts += defines
   nvccopts += std_options
@@ -272,6 +265,7 @@
   if args.x and args.x[0] == 'cuda':
     if args.cuda_log: Log('-x cuda')
     leftover = [pipes.quote(s) for s in leftover]
+    args.cuda_log = True
     if args.cuda_log: Log('using nvcc')
     return InvokeNvcc(leftover, log=args.cuda_log)
 
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
index 46e8aef..c10fb82 100644
--- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
@@ -37,13 +37,6 @@
 NVCC_PATH = '%{nvcc_path}'
 NVCC_VERSION = '%{cuda_version}'
 NVCC_TEMP_DIR = "%{nvcc_tmp_dir}"
-DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
-
-# Taken from environment variable for supported TF CUDA Compute Capabilities
-# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
-supported_cuda_compute_capabilities = os.environ.get(
-    'TF_CUDA_COMPUTE_CAPABILITIES',
-    DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
 
 def Log(s):
   print('gpus/crosstool: {0}'.format(s))
@@ -53,7 +46,7 @@
   """Extract the list of values for option from options.
 
   Args:
-    option: The option whose value to extract, without the leading '/'.
+    option: The option whose value to extract, including the leading '-' or '/'.
 
   Returns:
     1. A list of values, either directly following the option,
@@ -62,10 +55,11 @@
     2. The leftover options.
   """
 
-  parser = ArgumentParser(prefix_chars='/')
-  parser.add_argument('/' + option, nargs='*', action='append')
+  parser = ArgumentParser(prefix_chars='-/')
+  parser.add_argument(option, nargs='*', action='append')
+  option = option.lstrip('-/').replace('-', '_')
   args, leftover = parser.parse_known_args(argv)
-  if args and vars(args)[option]:
+  if args and vars(args).get(option):
     return (sum(vars(args)[option], []), leftover)
   return ([], leftover)
 
@@ -122,18 +116,18 @@
 
   nvcc_compiler_options, argv = GetNvccOptions(argv)
 
-  opt_option, argv = GetOptionValue(argv, 'O')
+  opt_option, argv = GetOptionValue(argv, '/O')
   opt = ['-g']
   if (len(opt_option) > 0 and opt_option[0] != 'd'):
     opt = ['-O2']
 
-  include_options, argv = GetOptionValue(argv, 'I')
+  include_options, argv = GetOptionValue(argv, '/I')
   includes = ["-I " + include for include in include_options]
 
-  defines, argv = GetOptionValue(argv, 'D')
+  defines, argv = GetOptionValue(argv, '/D')
   defines = ['-D' + define for define in defines]
 
-  undefines, argv = GetOptionValue(argv, 'U')
+  undefines, argv = GetOptionValue(argv, '/U')
   undefines = ['-U' + define for define in undefines]
 
   # The rest of the unrecognized options should be passed to host compiler
@@ -142,10 +136,10 @@
   m_options = ["-m64"]
 
   nvccopts = ['-D_FORCE_INLINES']
-  for capability in supported_cuda_compute_capabilities:
-    capability = capability.replace('.', '')
-    nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
-        capability, capability, capability)]
+  for capability in GetOptionValue(argv, "--cuda-gpu-arch")[0]:
+    capability = capability[len('sm_'):]
+    nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
+        capability, capability, capability)]
   nvccopts += nvcc_compiler_options
   nvccopts += undefines
   nvccopts += defines
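For the Windows wrapper, as on Linux, each --cuda-gpu-arch value now drives one nvcc -gencode flag instead of the old TF_CUDA_COMPUTE_CAPABILITIES environment variable. A minimal sketch of the translation (function name and inputs are illustrative):

    def gencode_flags(cuda_gpu_archs):
      flags = []
      for capability in cuda_gpu_archs:
        capability = capability[len('sm_'):]  # 'sm_70' -> '70'
        flags.append('-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
            capability, capability, capability))
      return flags

    print(gencode_flags(['sm_35', 'sm_70']))
    # ['-gencode=arch=compute_35,"code=sm_35,compute_35"',
    #  '-gencode=arch=compute_70,"code=sm_70,compute_70"']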
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 545aeeb..c587f11 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -840,10 +840,7 @@
         "--cuda-gpu-arch=sm_" + cap.replace(".", "")
         for cap in compute_capabilities
     ]
-
-    # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
-    # TODO(csigg): Make this consistent with cuda clang and pass unconditionally.
-    return "if_cuda_clang(%s)" % str(capability_flags)
+    return str(capability_flags)
 
 def _tpl_path(repository_ctx, filename):
     return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename))
@@ -1092,9 +1089,6 @@
             "%{cuda_version}": cuda_config.cuda_version,
             "%{nvcc_path}": nvcc_path,
             "%{gcc_host_compiler_path}": str(cc),
-            "%{cuda_compute_capabilities}": ", ".join(
-                ["\"%s\"" % c for c in cuda_config.compute_capabilities],
-            ),
             "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
         }
         repository_ctx.template(
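cuda_configure.bzl is what emits those --cuda-gpu-arch flags in the first place; after this change they are passed unconditionally rather than wrapped in if_cuda_clang, so both the clang path and the nvcc wrappers read the same flags. The Starlark list comprehension above, restated as plain Python with illustrative capability values:

    compute_capabilities = ['3.5', '6.0', '7.0']
    capability_flags = [
        '--cuda-gpu-arch=sm_' + cap.replace('.', '')
        for cap in compute_capabilities
    ]
    print(capability_flags)
    # ['--cuda-gpu-arch=sm_35', '--cuda-gpu-arch=sm_60', '--cuda-gpu-arch=sm_70']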
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 925fad7..75b32c7 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -297,9 +297,9 @@
 )
 
 filegroup(
-    name = "LoopOpsTdFiles",
+    name = "SCFTdFiles",
     srcs = [
-        "include/mlir/Dialect/LoopOps/LoopOps.td",
+        "include/mlir/Dialect/SCF/SCFOps.td",
         "include/mlir/Interfaces/ControlFlowInterfaces.td",
         "include/mlir/Interfaces/LoopLikeInterface.td",
         "include/mlir/Interfaces/SideEffects.td",
@@ -308,26 +308,26 @@
 )
 
 gentbl(
-    name = "LoopOpsIncGen",
+    name = "SCFIncGen",
     strip_include_prefix = "include",
     tbl_outs = [
         (
             "-gen-op-decls",
-            "include/mlir/Dialect/LoopOps/LoopOps.h.inc",
+            "include/mlir/Dialect/SCF/SCFOps.h.inc",
         ),
         (
             "-gen-op-defs",
-            "include/mlir/Dialect/LoopOps/LoopOps.cpp.inc",
+            "include/mlir/Dialect/SCF/SCFOps.cpp.inc",
         ),
         (
             "-gen-dialect-decls",
-            "include/mlir/Dialect/LoopOps/LoopOpsDialect.h.inc",
+            "include/mlir/Dialect/SCF/SCFOpsDialect.h.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/LoopOps/LoopOps.td",
+    td_file = "include/mlir/Dialect/SCF/SCFOps.td",
     td_srcs = [
-        ":LoopOpsTdFiles",
+        ":SCFTdFiles",
     ],
 )
 
@@ -337,30 +337,30 @@
     tbl_outs = [
         (
             "-gen-pass-decls",
-            "include/mlir/Dialect/LoopOps/Passes.h.inc",
+            "include/mlir/Dialect/SCF/Passes.h.inc",
         ),
     ],
     tblgen = ":mlir-tblgen",
-    td_file = "include/mlir/Dialect/LoopOps/Passes.td",
+    td_file = "include/mlir/Dialect/SCF/Passes.td",
     td_srcs = [
         ":PassBaseTdFiles",
     ],
 )
 
 cc_library(
-    name = "LoopOpsTransforms",
+    name = "SCFTransforms",
     srcs = glob([
-        "lib/Dialect/LoopOps/Transforms/*.cpp",
-        "lib/Dialect/LoopOps/Transforms/*.h",
+        "lib/Dialect/SCF/Transforms/*.cpp",
+        "lib/Dialect/SCF/Transforms/*.h",
     ]),
-    hdrs = ["include/mlir/Dialect/LoopOps/Passes.h"],
+    hdrs = ["include/mlir/Dialect/SCF/Passes.h"],
     includes = ["include"],
     deps = [
         ":Affine",
         ":IR",
-        ":LoopOps",
         ":LoopPassIncGen",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Transforms",
         "@llvm-project//llvm:support",
@@ -521,8 +521,8 @@
         ":AffinePassIncGen",
         ":Analysis",
         ":IR",
-        ":LoopOps",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         ":Transforms",
@@ -559,8 +559,8 @@
         ":Affine",
         ":ConversionPassIncGen",
         ":IR",
-        ":LoopOps",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         ":Transforms",
@@ -588,17 +588,17 @@
 )
 
 cc_library(
-    name = "LoopOps",
+    name = "SCFDialect",
     srcs = glob(
         [
-            "lib/Dialect/LoopOps/*.cpp",
-            "lib/Dialect/LoopOps/*.h",
-            "lib/Dialect/LoopOps/EDSC/*.cpp",
+            "lib/Dialect/SCF/*.cpp",
+            "lib/Dialect/SCF/*.h",
+            "lib/Dialect/SCF/EDSC/*.cpp",
         ],
     ),
     hdrs = glob([
-        "include/mlir/Dialect/LoopOps/*.h",
-        "include/mlir/Dialect/LoopOps/EDSC/*.h",
+        "include/mlir/Dialect/SCF/*.h",
+        "include/mlir/Dialect/SCF/EDSC/*.h",
     ]),
     includes = ["include"],
     deps = [
@@ -606,7 +606,7 @@
         ":EDSC",
         ":IR",
         ":LoopLikeInterface",
-        ":LoopOpsIncGen",
+        ":SCFIncGen",
         ":SideEffects",
         ":StandardOps",
         ":Support",
@@ -1113,9 +1113,9 @@
         ":GPUDialect",
         ":GPUPassIncGen",
         ":IR",
-        ":LoopOps",
         ":ParallelLoopMapperAttrGen",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         ":Transforms",
@@ -1324,8 +1324,8 @@
         ":GPUDialect",
         ":GPUToSPIRVIncGen",
         ":IR",
-        ":LoopOps",
         ":Pass",
+        ":SCFDialect",
         ":SPIRVDialect",
         ":SPIRVLowering",
         ":StandardToSPIRVConversions",
@@ -1883,7 +1883,7 @@
         ":ControlFlowInterfaces",
         ":IR",
         ":LoopLikeInterface",
-        ":LoopOps",
+        ":SCFDialect",
         ":SideEffects",
         ":StandardOps",
         ":Support",
@@ -2000,8 +2000,8 @@
         ":ControlFlowInterfaces",
         ":IR",
         ":LoopLikeInterface",
-        ":LoopOps",
         ":Pass",
+        ":SCFDialect",
         ":SideEffects",
         ":StandardOps",
         ":Support",
@@ -2037,8 +2037,8 @@
         ":GPUDialect",
         ":GPUTransforms",
         ":IR",
-        ":LoopOps",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         ":TransformUtils",
@@ -2061,9 +2061,9 @@
         ":Affine",
         ":ConversionPassIncGen",
         ":GPUDialect",
-        ":LoopOps",
         ":LoopsToGPU",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         ":Transforms",
@@ -2085,8 +2085,8 @@
         ":ConversionPassIncGen",
         ":IR",
         ":LLVMDialect",
-        ":LoopOps",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         ":TransformUtils",
@@ -2292,7 +2292,7 @@
         ":Affine",
         ":CallOpInterfaces",
         ":IR",
-        ":LoopOps",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         "@llvm-project//llvm:support",
@@ -2479,10 +2479,10 @@
         ":LLVMTransforms",
         ":LinalgToLLVM",
         ":LinalgToSPIRV",
-        ":LoopOpsTransforms",
         ":NVVMDialect",
         ":Parser",
         ":Pass",
+        ":SCFTransforms",
         ":StandardOpsTransforms",
         ":StandardToSPIRVConversions",
         ":StandardToStandard",
@@ -2566,8 +2566,6 @@
         ":LinalgToLLVM",
         ":LinalgToSPIRV",
         ":LinalgTransforms",
-        ":LoopOps",
-        ":LoopOpsTransforms",
         ":LoopPassIncGen",
         ":LoopsToGPUPass",
         ":NVVMDialect",
@@ -2575,6 +2573,8 @@
         ":QuantOps",
         ":QuantPassIncGen",
         ":ROCDLDialect",
+        ":SCFDialect",
+        ":SCFTransforms",
         ":SDBM",
         ":SPIRVDialect",
         ":SPIRVLowering",
@@ -3245,8 +3245,8 @@
         ":LinalgOps",
         ":LinalgPassIncGen",
         ":LinalgStructuredOpsIncGen",
-        ":LoopOps",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         ":TransformUtils",
@@ -3367,8 +3367,8 @@
         ":IR",
         ":LLVMDialect",
         ":LLVMTransforms",
-        ":LoopOps",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":Support",
         ":Transforms",
diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD
index c3dd157..a0312a5 100644
--- a/third_party/mlir/test.BUILD
+++ b/third_party/mlir/test.BUILD
@@ -163,8 +163,8 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
         "@llvm-project//mlir:LinalgTransforms",
-        "@llvm-project//mlir:LoopOps",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
diff --git a/third_party/toolchains/preconfig/generate/containers.bzl b/third_party/toolchains/preconfig/generate/containers.bzl
index b1d0389..9be398f 100644
--- a/third_party/toolchains/preconfig/generate/containers.bzl
+++ b/third_party/toolchains/preconfig/generate/containers.bzl
@@ -9,7 +9,7 @@
     "cuda10.1-cudnn7-centos6": "sha256:454b899657e87893ee5e68dc0f87df59b6a0a7418ae09cafcc3dd65ac71feca9",
     "cuda10.0-cudnn7-ubuntu16.04-manylinux2010": "sha256:5812d9d0ef0a3276fc5faaf4cd01f3d6e03d635893a6e2d2e04f6f01d626c432",
     "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:cc7f760195d7bbe283b45ae740409751d0b74d8ffbdc2f7a3cb62c71a71fbe25",
-    "cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython": "sha256:c460570b88eab3da92f06fdf30098d89be4de0f3b010ee3d39086f4d000dd3b8",
+    "cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython": "sha256:13aa5e700bb609521cd4365d4152d7d8f4118cae7ce174ce7d54cc529e21766a",
     "rocm-ubuntu16.04": "sha256:e645447dd6127325f3e97b8bf23424f637a8579d963b34fcc6772cf7cfaa0ebe",
     "windows-1803": "sha256:f109576c7c0c8a1783ff22b666e8923b52dbbe7933f69a1c7a7275202c304a12",
 }