Merge pull request #42503 from ahmedsabie:T1

PiperOrigin-RevId: 329267324
Change-Id: I8829d03b6d20531502ae9f638864c9a8d82470fa
diff --git a/.github/bot_config.yml b/.github/bot_config.yml
index d0e7256..952ff91 100644
--- a/.github/bot_config.yml
+++ b/.github/bot_config.yml
@@ -40,6 +40,22 @@
 # assignees
 filesystem_security_assignee:
    - mihaimaruseac
+   
+tflite_micro_path:
+   - tensorflow/lite/micro
+   
+tflite_micro_comment: >
+   Thanks for contributing to TensorFlow Lite Micro.
+   
+
+   To keep this process moving along, we'd like to make sure that you have completed the items on this list:
+      * Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md)
+      * Created a [TF Lite Micro GitHub issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
+      * Linked to the issue from the PR description
+      
+
+   We would like to have a discussion on the GitHub issue first to determine the best path forward, and then proceed to the PR review.
+
 # Cuda Comment
 cuda_comment: >
    From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
diff --git a/RELEASE.md b/RELEASE.md
index 7057657..6890352 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -37,6 +37,9 @@
 * XLA:CPU and XLA:GPU devices are no longer registered by default. Use
   `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
   removed).
+* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
+  `tf.complex64` or `tf.complex128`, because the behavior of these ops is not
+  well defined for complex types.
 
 ## Known Caveats
 
@@ -120,6 +123,13 @@
       customization of how gradients are aggregated across devices, as well as
       `gradients_transformers` to allow for custom gradient transformations
       (such as gradient clipping).
+    * The `steps_per_execution` argument in `compile()` is no longer
+      experimental; if you were passing `experimental_steps_per_execution`,
+      rename it to `steps_per_execution` in your code. This argument controls
+      the number of batches to run during each `tf.function` call when calling
+      `fit()`. Running multiple batches inside a single `tf.function` call can
+      greatly improve performance on TPUs, or for small models with a large
+      Python overhead.
 * `tf.function` / AutoGraph:
   * Added `experimental_follow_type_hints` argument for `tf.function`. When
     True, the function may use type annotations to optimize the tracing
@@ -147,6 +157,8 @@
     * Deprecate `Interpreter::UseNNAPI(bool)` C++ API
       * Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
     * Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
+    * TFLite Profiler for Android is available. See the detailed
+      [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
     * <ADD RELEASE NOTES HERE>
 *   `tf.random`:
     * <ADD RELEASE NOTES HERE>
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 0b3aff7..0cace4d 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -58,9 +58,9 @@
     visibility = ["//visibility:public"],
 )
 
-filegroup(
+cc_library(
     name = "pywrap_required_hdrs",
-    srcs = [
+    textual_hdrs = [
         "c_api_internal.h",
         "c_api_macros.h",
         "conversion_macros.h",
@@ -387,6 +387,7 @@
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
         "//tensorflow/core/platform",
+        "//tensorflow/core/platform:blocking_counter",
         "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index b429703..81fb9d1 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -35,6 +35,7 @@
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/platform/net.h"
@@ -560,6 +561,21 @@
   collective_executor_handle->get()->StartAbort(status->status);
 }
 
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
+                                                            const char* task,
+                                                            TF_Status* status) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  auto collective_executor_handle = context->GetCollectiveExecutorHandle();
+  tensorflow::Notification done;
+  collective_executor_handle->get()->remote_access()->CheckPeerHealth(
+      task, [&done, status](const Status& s) {
+        status->status = s;
+        done.Notify();
+      });
+  done.WaitForNotification();
+}
+
 TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
   TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
   result->num_items = num_items;
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index ebd14b4..a08d4f2 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -238,6 +238,13 @@
 TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
                                                   TF_Status* status);
 
+// Checks the health of collective ops peers. An explicit health check is
+// needed in multi-worker collective ops to detect failures in the cluster; if
+// a peer is down, collective ops may hang.
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
+                                                            const char* task,
+                                                            TF_Status* status);
+
 // Information about the shape of a Tensor and its type.
 struct TF_ShapeAndType {
   // Number of dimensions. -1 indicates unknown rank.
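
As a quick reference for reviewers, here is a minimal usage sketch (not part of this patch) of the new health-check entry point. `ctx` is assumed to be a context already configured for a multi-worker cluster, and the task name is illustrative. Since `TFE_AbortCollectiveOps` aborts pending collectives with the status it is passed, forwarding the failed health status is the intended pattern:

```c++
// Sketch: probe one peer, then abort pending collectives with the failure
// status if the peer is unreachable. `ctx` and the task name
// "/job:worker/replica:0/task:1" are assumptions for this example.
TF_Status* health_status = TF_NewStatus();
TFE_CollectiveOpsCheckPeerHealth(ctx, "/job:worker/replica:0/task:1",
                                 health_status);
if (TF_GetCode(health_status) != TF_OK) {
  // StartAbort is invoked with this status, unblocking hung collectives.
  TFE_AbortCollectiveOps(ctx, health_status);
}
TF_DeleteStatus(health_status);
```

Note that the call blocks until the peer responds or the health check fails.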
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index 3fff9bc..ec8cfe4 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -1704,66 +1704,5 @@
   TF_DeleteFunction(func1);
 }
 
-// This test only works when the TF build includes XLA compiler. One way to set
-// this up is via bazel build option "--define with_xla_support=true".
-//
-// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
-// something like TENSORFLOW_CAPI_USE_XLA.
-#ifdef TENSORFLOW_EAGER_USE_XLA
-TEST_F(CApiFunctionTest, StatelessIf_XLA) {
-  TF_Function* func;
-  const std::string funcName = "BranchFunc";
-  DefineFunction(funcName.c_str(), &func);
-  TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  TF_Operation* feed = Placeholder(host_graph_, s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  TF_OperationDescription* desc =
-      TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
-  TF_AddInput(desc, {true_cond, 0});
-  TF_Output inputs[] = {{feed, 0}};
-  TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-  TF_SetAttrType(desc, "Tcond", TF_BOOL);
-  TF_DataType inputType = TF_INT32;
-  TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
-  TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
-  TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
-  TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
-  TF_SetDevice(desc, "/device:XLA_CPU:0");
-  auto op = TF_FinishOperation(desc, s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-  ASSERT_NE(op, nullptr);
-
-  // Create a session for this graph.
-  CSession csession(host_graph_, s_, /*use_XLA*/ true);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  // Run the graph.
-  csession.SetInputs({{feed, Int32Tensor(17)}});
-  csession.SetOutputs({op});
-  csession.Run(s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-  TF_Tensor* out = csession.output_tensor(0);
-  ASSERT_TRUE(out != nullptr);
-  EXPECT_EQ(TF_INT32, TF_TensorType(out));
-  EXPECT_EQ(0, TF_NumDims(out));  // scalar
-  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
-  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
-  EXPECT_EQ(-17, *output_contents);
-
-  // Clean up
-  csession.CloseAndDelete(s_);
-  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
-  TF_DeleteFunction(func);
-}
-#endif  // TENSORFLOW_EAGER_USE_XLA
-
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 451ade4..1a3b348 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -6,7 +6,6 @@
     "tf_copts",
     "tf_cuda_cc_test",
     "tf_cuda_library",
-    "tfe_xla_copts",
 )
 load(
     "//tensorflow/core/platform:build_config.bzl",
@@ -31,7 +30,7 @@
         "c_api_unified_experimental.h",
     ],
     hdrs = ["c_api.h"],
-    copts = tf_copts() + tfe_xla_copts(),
+    copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = select({
         "//tensorflow:android": [
@@ -72,13 +71,6 @@
             "//tensorflow/core:protos_all_cc",
             "//tensorflow/core/profiler/lib:traceme",
         ],
-    }) + select({
-        "//tensorflow:with_xla_support": [
-            "//tensorflow/compiler/tf2xla:xla_compiler",
-            "//tensorflow/compiler/jit",
-            "//tensorflow/compiler/jit:xla_device",
-        ],
-        "//conditions:default": [],
     }) + [
         "@com_google_absl//absl/memory",
         "//tensorflow/core/common_runtime/eager:eager_operation",
@@ -228,7 +220,6 @@
         "gradients_test.cc",
     ],
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + ["nomac"],
     deps = [
@@ -257,6 +248,73 @@
 )
 
 cc_library(
+    name = "mnist_gradients_testutil",
+    srcs = [
+        "mnist_gradients_testutil.cc",
+    ],
+    hdrs = [
+        "mnist_gradients_testutil.h",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        ":abstract_tensor_handle",
+        ":c_api_experimental",
+        ":c_api_unified_internal",
+        ":gradients_internal",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c:tf_tensor",
+        "//tensorflow/c/experimental/ops:array_ops",
+        "//tensorflow/c/experimental/ops:math_ops",
+        "//tensorflow/c/experimental/ops:nn_ops",
+        "//tensorflow/core/lib/llvm_rtti",
+        "//tensorflow/core/platform:status",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_cuda_cc_test(
+    name = "mnist_gradients_test",
+    size = "small",
+    srcs = [
+        "mnist_gradients_test.cc",
+    ],
+    args = ["--heap_check=local"],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    tags = tf_cuda_tests_tags() + [
+        "nomac",
+    ],
+    deps = [
+        ":abstract_tensor_handle",
+        ":c_api_experimental",
+        ":c_api_test_util",
+        ":c_api_unified_internal",
+        ":gradients_internal",
+        ":mnist_gradients_testutil",
+        "//tensorflow/c:c_api",
+        "//tensorflow/c:c_test_util",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/c/experimental/gradients:math_grad",
+        "//tensorflow/c/experimental/gradients:nn_grad",
+        "//tensorflow/c/experimental/ops:array_ops",
+        "//tensorflow/c/experimental/ops:math_ops",
+        "//tensorflow/c/experimental/ops:nn_ops",
+        "//tensorflow/cc/profiler",
+        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/lib/llvm_rtti",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
     name = "abstract_tensor_handle",
     hdrs = ["abstract_tensor_handle.h"],
     visibility = [
@@ -484,7 +542,6 @@
         "c_api_debug_test.cc",
         "c_api_test.cc",
     ],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "noguitar",  # TODO(b/155445984): flaky
         #"guitar",
@@ -539,7 +596,6 @@
     ],
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "no_windows",
     ],
@@ -572,7 +628,6 @@
     ],
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "no_windows",
     ],
@@ -591,7 +646,6 @@
     ],
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "no_windows",
         "noasan",  # leaks gRPC server instances
@@ -625,7 +679,6 @@
     ],
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     tags = [
         "no_windows",
     ],
@@ -660,7 +713,7 @@
         "c_api_experimental.h",
         "c_api_unified_experimental.h",
     ],
-    copts = tf_copts() + tfe_xla_copts(),
+    copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = select({
         "//tensorflow:android": [
@@ -732,7 +785,6 @@
         "c_api_experimental_test.cc",
     ],
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + ["nomac"],
     deps = [
@@ -756,7 +808,6 @@
         "c_api_unified_experimental_test.cc",
     ],
     args = ["--heap_check=local"],
-    extra_copts = tfe_xla_copts(),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + ["nomac"],
     deps = [
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index fefa753..6528d21 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -51,9 +51,6 @@
 #include "tensorflow/core/protobuf/device_filters.pb.h"
 #include "tensorflow/core/protobuf/error_codes.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
-#ifdef TENSORFLOW_EAGER_USE_XLA
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#endif  // TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -745,7 +742,6 @@
       opts->session_options.options,
       static_cast<tensorflow::ContextDevicePlacementPolicy>(
           opts->device_placement_policy),
-      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
       opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
       /*device_mgr_owned*/ true, r,
       tensorflow::GetDefaultCustomKernelCreator()));
@@ -1149,26 +1145,23 @@
   tensorflow::unwrap(op)->Release();
 }
 
+const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
+  return tensorflow::unwrap(op)->Name().c_str();
+}
+
+TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
+  return tensorflow::wrap(
+      &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
+}
+
 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
   status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
 }
 
-const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
+const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
   return tensorflow::unwrap(op)->DeviceName().c_str();
 }
 
-void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
-#ifdef TENSORFLOW_EAGER_USE_XLA
-  tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
-  if (!s.ok()) {
-    LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
-  }
-#else
-  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
-                  "built with XLA support.";
-#endif  // TENSORFLOW_EAGER_USE_XLA
-}
-
 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
   status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
 }
@@ -1181,6 +1174,15 @@
        static_cast<size_t>(num_inputs)});
 }
 
+extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
+  return tensorflow::unwrap(op)->GetInputs().size();
+}
+
+extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
+                                            TF_Status* status) {
+  return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
+}
+
 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                               unsigned char* is_list, TF_Status* status) {
   TF_AttrType ret = TF_ATTR_INT;
@@ -1486,7 +1488,7 @@
   tensorflow::unwrap(ctx)->EndStep();
 }
 
-const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
+const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
   return tensorflow::wrap(
       &OperationFromInterface(tensorflow::unwrap(op))->Attrs());
 }
@@ -1551,8 +1553,67 @@
       TFE_OpSetAttrFunction(op, attr_name, func_op);
       TFE_DeleteOp(func_op);
     } break;
-    case tensorflow::AttrValue::kList:
-      TF_FALLTHROUGH_INTENDED;
+    case tensorflow::AttrValue::kList: {
+      // String
+      if (const int s_size = default_value.list().s_size()) {
+        absl::InlinedVector<const void*, 4> values_vector;
+        absl::InlinedVector<size_t, 4> lengths_vector;
+        for (int i = 0; i < s_size; ++i) {
+          const string& v = default_value.list().s(i);
+          values_vector.push_back(v.data());
+          lengths_vector.push_back(v.size());
+        }
+        TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
+                                lengths_vector.data(), s_size);
+      }
+
+      // Int
+      if (const int i_size = default_value.list().i_size()) {
+        absl::InlinedVector<int64_t, 4> i_vector;
+        for (int i = 0; i < i_size; ++i) {
+          i_vector.push_back(default_value.list().i(i));
+        }
+        TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
+      }
+      // Float
+      if (const int f_size = default_value.list().f_size()) {
+        absl::InlinedVector<float, 4> f_vector;
+        for (int i = 0; i < f_size; ++i) {
+          f_vector.push_back(default_value.list().f(i));
+        }
+        TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
+      }
+      // Bool
+      if (const int b_size = default_value.list().b_size()) {
+        absl::InlinedVector<unsigned char, 4> b_vector;
+        for (int i = 0; i < b_size; i++) {
+          b_vector.push_back(default_value.list().b(i));
+        }
+        TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
+      }
+      // Type
+      if (const int type_size = default_value.list().type_size()) {
+        absl::InlinedVector<unsigned int, 4> type_vector;
+        for (int i = 0; i < type_size; ++i) {
+          type_vector.push_back(default_value.list().type(i));
+        }
+        TFE_OpSetAttrTypeList(
+            op, attr_name,
+            reinterpret_cast<const TF_DataType*>(type_vector.data()),
+            type_size);
+      }
+
+      // The remaining list types (shape, func, tensor) are not supported.
+      if (default_value.list().shape_size() > 0 ||
+          default_value.list().func_size() > 0 ||
+          default_value.list().tensor_size() > 0) {
+        TF_SetStatus(
+            status, TF_UNIMPLEMENTED,
+            tensorflow::strings::StrCat("Unable to set attr for default value: ",
+                                        default_value.DebugString())
+                .data());
+      }
+    } break;
     case tensorflow::AttrValue::kTensor:
       TF_FALLTHROUGH_INTENDED;
     case tensorflow::AttrValue::kPlaceholder:
@@ -1612,19 +1673,12 @@
     return status.status;
   }
 
-  tensorflow::Status Execute(tensorflow::EagerOperation* op,
+  tensorflow::Status Execute(const tensorflow::EagerOperation* op,
                              tensorflow::TensorHandle** retvals,
                              int* num_retvals) override {
-    std::vector<TFE_TensorHandle*> inputs;
-    inputs.reserve(op->Inputs().size());
-    for (int i = 0; i < op->Inputs().size(); ++i) {
-      op->Inputs()[i]->Ref();
-      inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
-    }
     std::vector<TFE_TensorHandle*> outputs(*num_retvals);
     TF_Status status;
-    device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
-                    wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
+    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
                     info_);
     if (status.status.ok()) {
       for (int i = 0; i < *num_retvals; ++i) {
@@ -1634,10 +1688,6 @@
         TFE_DeleteTensorHandle(outputs[i]);
       }
     }
-
-    for (auto inp : inputs) {
-      TFE_DeleteTensorHandle(inp);
-    }
     return status.status;
   }
 
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 5afe304..a58c681 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -248,22 +248,22 @@
 TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
                                         const char* op_or_function_name,
                                         TF_Status* status);
-
 TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
 
+// Returns the op or function name `op` will execute.
+//
+// The returned string remains valid throughout the lifetime of 'op'.
+TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op,
+                                                TF_Status* status);
+TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op,
+                                                    TF_Status* status);
+
 TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
                                            TF_Status* status);
 // The returned string remains valid throughout the lifetime of 'op'.
-TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
+TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op,
                                                   TF_Status* status);
 
-// When 'enable' is set to 1, and if TensorFlow library is built with XLA
-// support, a subsequent TFE_Execute() call on `op` will run the op via XLA.
-//
-// If the library is not built with XLA support, this call would be a no-op.
-TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op,
-                                                   unsigned char enable);
-
 TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input,
                                           TF_Status* status);
 
@@ -272,6 +272,23 @@
                                               int num_inputs,
                                               TF_Status* status);
 
+// Fetches the current number of inputs attached to `op`.
+//
+// Does not use the operation's definition to determine how many inputs should
+// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an
+// already-finalized operation.
+//
+// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat
+// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a
+// particular named input list, which may only be part of the op's inputs).
+TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op,
+                                                  TF_Status* status);
+// Returns a borrowed reference to one of `op`'s inputs. Use
+// `TFE_TensorHandleCopySharingTensor` to make a new reference.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op,
+                                                           int index,
+                                                           TF_Status* status);
+
 TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op,
                                                     const char* attr_name,
                                                     unsigned char* is_list,
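
The `CloneOp` helper added to c_api_test.cc later in this patch exercises these accessors end to end; as a shorter reference, here is a sketch of inspecting an already-finalized op with the new const-qualified getters (`op` and its setup are assumed):

```c++
// Sketch only: read back name, device, and flat inputs from a finalized op.
TF_Status* status = TF_NewStatus();
const char* op_name = TFE_OpGetName(op, status);      // valid while `op` lives
const char* op_device = TFE_OpGetDevice(op, status);  // may be ""
int flat_count = TFE_OpGetFlatInputCount(op, status);
for (int i = 0; i < flat_count; ++i) {
  // Borrowed reference; use TFE_TensorHandleCopySharingTensor if it must
  // outlive `op`.
  TFE_TensorHandle* input = TFE_OpGetFlatInput(op, i, status);
  (void)input;
}
TF_DeleteStatus(status);
```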
diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc
index dd55f05..b5721cd 100644
--- a/tensorflow/c/eager/c_api_debug.cc
+++ b/tensorflow/c/eager/c_api_debug.cc
@@ -22,9 +22,6 @@
 #include "tensorflow/c/tf_status_internal.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/platform/status.h"
-#ifdef TENSORFLOW_EAGER_USE_XLA
-#include "tensorflow/compiler/jit/xla_device.h"
-#endif  // TENSORFLOW_EAGER_USE_XLA
 
 using tensorflow::string;
 
@@ -64,87 +61,6 @@
     return nullptr;
   }
 
-#ifdef TENSORFLOW_EAGER_USE_XLA
-  auto* device = absl::get<tensorflow::Device*>(handle->device());
-
-  // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
-  auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
-  if (xla_device != nullptr) {
-    tensorflow::XlaDevice::PaddedShapeFn shape_fn =
-        xla_device->metadata().padded_shape_fn();
-    xla::Shape padded_shape;
-    status->status = shape_fn(*tensor, &padded_shape);
-    if (!status->status.ok()) {
-      return nullptr;
-    }
-    if (VLOG_IS_ON(3)) {
-      std::vector<tensorflow::int64> shape_to_log =
-          TensorShapeAsVector(*handle, &status->status);
-      if (!status->status.ok()) {
-        // Ignore the status here as we are simply logging.
-        status->status = tensorflow::Status::OK();
-      } else {
-        VLOG(3) << "Fully padded shape of ["
-                << absl::StrJoin(shape_to_log, ", ") << "] is "
-                << padded_shape.DebugString();
-      }
-    }
-
-    if (padded_shape.IsTuple()) {
-      if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
-        // Currently, the only case of XlaTensor containing a tuple shape is to
-        // represent 64 bit ints, doubles, and complex numbers (we don't support
-        // 64bit complex numbers).
-        status->status = tensorflow::errors::InvalidArgument(
-            "XlaTensors should only contain tuples of size 2. Shape: ",
-            padded_shape.DebugString());
-        return nullptr;
-      }
-
-      // shape0 is not a const& because we will assign it to padded_shape below.
-      // It is illegal to assign a part of a message to itself.
-      xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
-      const xla::Shape& shape1 =
-          xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
-      if (shape0.IsTuple() || shape1.IsTuple()) {
-        status->status = tensorflow::errors::InvalidArgument(
-            "XlaTensors should not contain nested tuples. Shape: ",
-            padded_shape.DebugString());
-        return nullptr;
-      }
-      if (!xla::ShapeUtil::Equal(shape0, shape1)) {
-        status->status = tensorflow::errors::InvalidArgument(
-            "Subshapes of XlaTensors should be the same. Shape: ",
-            padded_shape.DebugString());
-        return nullptr;
-      }
-
-      // Since the only case we handle here are two equal subshapes, we
-      // simply return one of them. The caller will interpret it as this
-      // shape directly storing the 64bit types. This approximation is good
-      // enough for this API's debugging use case.
-      padded_shape = shape0;
-    }
-
-    int rank = padded_shape.dimensions_size();
-    std::vector<tensorflow::int64> dev_dims;
-    dev_dims.reserve(rank);
-    if (rank == 1) {
-      // Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
-      dev_dims.push_back(padded_shape.dimensions(0));
-    } else {
-      for (int i = rank - 1; i >= 0; --i) {
-        tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i);
-        dev_dims.push_back(padded_shape.dimensions(dim_index));
-      }
-    }
-    status->status = tensorflow::Status::OK();
-    return new TFE_TensorDebugInfo(dev_dims);
-  }
-#endif  // TENSORFLOW_EAGER_USE_XLA
-
-  // If the tensor is not an XLA tensor, the device shape is
-  // the same as regular tensor shape.
   std::vector<tensorflow::int64> dev_dims =
       TensorShapeAsVector(*handle, &status->status);
   if (!status->status.ok()) {
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index 7390cf2..eabb159 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -486,29 +486,6 @@
       static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
 }
 
-void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
-                                          TFE_ContextMirroringPolicy policy) {
-  options->mirroring_policy = policy;
-}
-
-void TFE_ContextSetThreadLocalMirroringPolicy(
-    TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetThreadLocalMirroringPolicy(
-      static_cast<tensorflow::ContextMirroringPolicy>(policy));
-}
-
-// Note: this function looks up a thread local policy. So it should be called in
-// the appropriate client thread. In particular, in async mode, it may not be
-// safe to call this function from the async EagerExecutor threads.
-extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
-    TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
-}
-
 void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
                                                bool lazy_copy) {
   options->lazy_remote_inputs_copy = lazy_copy;
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 1af76c0..12546c6 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -265,33 +265,6 @@
 TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
     TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
 
-// LINT.IfChange
-// Note: Keep in sync with internal copy of enum in eager/context.h.
-typedef enum TFE_ContextMirroringPolicy {
-  // Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
-  // copies with their own lifetime.
-  TFE_MIRRORING_NONE = 0,
-  // Mirroring any remote tensor handles, associating them with the lifetime of
-  // the local TensorHandle.
-  TFE_MIRRORING_ALL = 1,
-} TFE_ContextMirroringPolicy;
-// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
-
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetMirroringPolicy(
-    TFE_ContextOptions*, TFE_ContextMirroringPolicy);
-
-// Sets a thread-local mirroring policy. After this call, other calls to
-// TFE_Execute in the same thread will use the mirroring policy specified here
-// instead of the mirroring policy used to construct the context. This has no
-// effect on the mirroring policy used by other program threads.
-TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalMirroringPolicy(
-    TFE_Context*, TFE_ContextMirroringPolicy);
-
-// Returns the mirroring policy to be used by this context in the current
-// thread.
-TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
-    TFE_Context*);
-
 // Sets whether to copy the remote inputs of a function lazily.
 TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
     TFE_ContextOptions*, bool lazy_copy);
@@ -441,7 +414,7 @@
 
 // Fetch a reference to `op`'s attributes. The returned reference is only valid
 // while `op` is alive.
-const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
+TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);
 // Add attributes in `attrs` to `op`.
 //
 // Does not overwrite or update existing attributes, but adds new ones.
@@ -462,7 +435,11 @@
                                                    size_t proto_len,
                                                    TF_Status* status);
 
-#define TFE_CUSTOM_DEVICE_VERSION 2
+// TODO(b/166642410): It would be nice, for custom devices and for other users,
+// to have a non-string representation of devices (TF_Device) extracted from
+// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
+
+#define TFE_CUSTOM_DEVICE_VERSION 3
 
 // Struct to be filled in
 typedef struct TFE_CustomDevice {
@@ -481,9 +458,16 @@
                                                void* device_info);
 
   // Method to execute an operation.
-  void (*execute)(TFE_Context* context, int num_inputs,
-                  TFE_TensorHandle** inputs, const char* operation_name,
-                  const TFE_OpAttrs* attributes, int* num_outputs,
+  //
+  // Arguments provide enough information to reconstruct the original `TFE_Op`,
+  // or construct a transformed version, by inspecting the passed `op`.
+  //
+  // TFE_OpGetDevice(op) records the original placement of the operation. It may
+  // be an empty string if no device was explicitly requested, but will
+  // otherwise be the name of this custom device. Ops are placed onto a custom
+  // device if any of their inputs are on that custom device, but custom devices
+  // are free to set a bad status in order to require explicit placement.
+  void (*execute)(const TFE_Op* op, int* num_outputs,
                   TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
 
   // Method to delete a device.
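
Under the version-3 contract, an `execute` callback reconstructs whatever it needs from the const `TFE_Op` instead of receiving unpacked arguments. A skeletal callback follows, with error checks and underlying-device routing elided for brevity; the full pattern is `LoggingDeviceExecute` in custom_device_testutil.cc below:

```c++
// Skeleton of a version-3 custom-device execute callback (illustrative only).
void ExampleExecute(const TFE_Op* original_op, int* num_outputs,
                    TFE_TensorHandle** outputs, TF_Status* s,
                    void* device_info) {
  TFE_Context* context = TFE_OpGetContext(original_op, s);
  TFE_Op* op = TFE_NewOp(context, TFE_OpGetName(original_op, s), s);
  TFE_OpAddAttrs(op, TFE_OpGetAttrs(original_op));
  int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
  for (int i = 0; i < num_inputs; ++i) {
    TFE_OpAddInput(op, TFE_OpGetFlatInput(original_op, i, s), s);
  }
  TFE_Execute(op, outputs, num_outputs, s);
  TFE_DeleteOp(op);
}
```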
diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc
index a4d3141..4975d30 100644
--- a/tensorflow/c/eager/c_api_experimental_test.cc
+++ b/tensorflow/c/eager/c_api_experimental_test.cc
@@ -316,86 +316,6 @@
   TF_DeleteStatus(status);
 }
 
-#ifdef TENSORFLOW_EAGER_USE_XLA
-TEST(CAPI, Function_ident_XLA_CPU) {
-  // First create a simple identity function.
-  TF_Graph* function_graph = TF_NewGraph();
-  TF_OperationDescription* arg_descr =
-      TF_NewOperation(function_graph, "Placeholder", "arg");
-  TF_SetAttrType(arg_descr, "dtype", TF_INT32);
-  TF_Status* status = TF_NewStatus();
-  TF_Operation* arg = TF_FinishOperation(arg_descr, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_OperationDescription* id_descr =
-      TF_NewOperation(function_graph, "Identity", "id");
-  TF_SetAttrType(id_descr, "T", TF_INT32);
-  TF_AddInput(id_descr, {arg, 0});
-  TF_Operation* id = TF_FinishOperation(id_descr, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_Output input{arg, 0};
-  TF_Output output{id, 0};
-  TF_Function* fn =
-      TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
-                         &output, nullptr, nullptr, "test", status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteGraph(function_graph);
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_Context* ctx = TFE_NewContext(opts, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TFE_DeleteContextOptions(opts);
-  TFE_ContextAddFunction(ctx, fn, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteFunction(fn);
-
-  for (bool async : {false, true, false}) {
-    TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
-    TFE_Executor* executor = TFE_NewExecutor(async);
-    TFE_ContextSetExecutorForThread(ctx, executor);
-    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK);
-    TF_Tensor* t =
-        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
-    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
-    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    TF_DeleteTensor(t);
-
-    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    TFE_OpAddInput(op, h, status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-
-    // Now run it via XLA.
-    TFE_OpSetXLACompilation(op, true);
-
-    std::vector<TFE_TensorHandle*> result;
-    result.push_back(nullptr);
-    int num_retvals = 1;
-    TFE_Execute(op, result.data(), &num_retvals, status);
-    TFE_DeleteOp(op);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    ASSERT_EQ(num_retvals, 1);
-
-    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
-    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
-    TFE_ContextSetExecutorForThread(ctx, old_executor);
-    TFE_ExecutorWaitForAllPendingNodes(executor, status);
-    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-    TFE_DeleteExecutor(executor);
-    TFE_DeleteExecutor(old_executor);
-    TFE_DeleteTensorHandle(h);
-    TF_DeleteTensor(r);
-    TFE_DeleteTensorHandle(result[0]);
-  }
-  TFE_ContextRemoveFunction(ctx, "ident", status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TFE_DeleteContext(ctx);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteStatus(status);
-}
-#endif  // TENSORFLOW_EAGER_USE_XLA
-
 void Executor_MatMul_CPU(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 4d9be0c..356476c 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -32,7 +32,6 @@
   bool async = false;
   TFE_ContextDevicePlacementPolicy device_placement_policy{
       TFE_DEVICE_PLACEMENT_SILENT};
-  TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
   // If true, lazily copy the remote inputs of a function to the target devices.
   bool lazy_remote_inputs_copy = true;
   // If true, use TFRT backend
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 7241765..fd208c6 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -20,6 +20,7 @@
 #include <string>
 
 // clang-format off
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
 
@@ -876,89 +877,6 @@
   TF_DeleteStatus(status);
 }
 
-#ifdef TENSORFLOW_EAGER_USE_XLA
-void Execute_MatMul_XLA_CPU(bool async) {
-  TF_Status* status = TF_NewStatus();
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
-  TFE_Context* ctx = TFE_NewContext(opts, status);
-  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TFE_DeleteContextOptions(opts);
-
-  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
-  TFE_Op* matmul = MatMulOp(ctx, m, m);
-
-  TFE_OpSetXLACompilation(matmul, true);
-
-  TFE_TensorHandle* retvals[1] = {nullptr};
-  int num_retvals = 1;
-  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
-  // Running a primitive TF operator via XLA is not yet supported.
-  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-
-  TFE_DeleteOp(matmul);
-  TFE_DeleteTensorHandle(m);
-  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-
-  EXPECT_EQ(1, num_retvals);
-
-  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
-  TFE_DeleteTensorHandle(retvals[0]);
-  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  float product[4] = {0};
-  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
-  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
-  TF_DeleteTensor(t);
-  EXPECT_EQ(7, product[0]);
-  EXPECT_EQ(10, product[1]);
-  EXPECT_EQ(15, product[2]);
-  EXPECT_EQ(22, product[3]);
-  TFE_DeleteContext(ctx);
-  TF_DeleteStatus(status);
-}
-TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
-TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); }
-
-void Execute_Min_XLA_CPU(bool async) {
-  TF_Status* status = TF_NewStatus();
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
-  TFE_Context* ctx = TFE_NewContext(opts, status);
-  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TFE_DeleteContextOptions(opts);
-
-  TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
-  TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
-  TFE_Op* minOp = MinOp(ctx, input, axis);
-
-  TFE_OpSetXLACompilation(minOp, true);
-
-  TFE_TensorHandle* retvals[1] = {nullptr};
-  int num_retvals = 1;
-  TFE_Execute(minOp, &retvals[0], &num_retvals, status);
-  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TFE_DeleteOp(minOp);
-  TFE_DeleteTensorHandle(input);
-  TFE_DeleteTensorHandle(axis);
-  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  ASSERT_EQ(1, num_retvals);
-
-  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
-  TFE_DeleteTensorHandle(retvals[0]);
-  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  float output[2] = {0};
-  EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
-  memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
-  TF_DeleteTensor(t);
-  EXPECT_EQ(1, output[0]);
-  EXPECT_EQ(3, output[1]);
-  TFE_DeleteContext(ctx);
-  TF_DeleteStatus(status);
-}
-TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
-TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); }
-#endif  // TENSORFLOW_EAGER_USE_XLA
-
 void ExecuteWithTracing(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -1274,6 +1192,68 @@
   TF_DeleteStatus(status);
 }
 
+// Same test as above, except use SetOpAttrValueScalar to set attrs.
+TEST(CAPI, TestTFE_SetOpAttrs) {
+  // Test that TFE_OpSetAttrString doesn't hold on to the value after it
+  // returns.
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  std::vector<int64_t> dims(4, 1);
+  TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TF_Tensor* tensor =
+      TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
+  float tensor_data[] = {1};
+  memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
+  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, tensor_handle, status);
+  TF_DeleteTensor(tensor);
+  TFE_DeleteTensorHandle(tensor_handle);
+
+  tensorflow::AttrValue i_list_values;
+  for (int i = 0; i < 4; ++i) {
+    i_list_values.mutable_list()->add_i(1);
+  }
+  SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status);
+  SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status);
+
+  tensorflow::AttrValue padding_value;
+  *padding_value.mutable_s() = "VALID";
+  tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status);
+
+  tensorflow::AttrValue data_format_value;
+  *data_format_value.mutable_s() = "NHWC";
+  tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format",
+                                   status);
+
+  TFE_OpSetAttrType(op, "T", TF_FLOAT);
+
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TFE_TensorHandle* retvals[1];
+  int num_retvals = 1;
+  TFE_Execute(op, &retvals[0], &num_retvals, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  ASSERT_EQ(1, num_retvals);
+
+  tensor = TFE_TensorHandleResolve(retvals[0], status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  EXPECT_EQ(4, TF_TensorByteSize(tensor));
+  TF_DeleteTensor(tensor);
+  TFE_DeleteTensorHandle(retvals[0]);
+
+  TFE_DeleteOp(op);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+}
+
 TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
@@ -1620,4 +1600,91 @@
   TFE_DeleteContext(ctx);
 }
 
+// Needs to work with a const TFE_Op since custom devices should not modify the
+// op they are called with.
+TFE_Op* CloneOp(const TFE_Op* other) {
+  TF_Status* status = TF_NewStatus();
+  TFE_Context* context = TFE_OpGetContext(other, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char* op_name = TFE_OpGetName(other, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_Op* ret = TFE_NewOp(context, op_name, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char* device = TFE_OpGetDevice(other, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpSetDevice(ret, device, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other));
+  int num_inputs = TFE_OpGetFlatInputCount(other, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  for (int input_index = 0; input_index < num_inputs; ++input_index) {
+    TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TFE_OpAddInput(ret, input, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  }
+  TF_DeleteStatus(status);
+  return ret;
+}
+
+TEST(CAPI, TestTFE_OpRecreation) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  // Clone an op with attributes and a device set.
+  TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64);
+  TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status)));
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpSetDevice(original_var_op,
+                  "/job:localhost/replica:0/task:0/device:CPU:0", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_Op* cloned = CloneOp(original_var_op);
+
+  EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0",
+            std::string(TFE_OpGetDevice(cloned, status)));
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status)));
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  int num_retvals = 1;
+  TFE_TensorHandle* ret;
+  TFE_Execute(cloned, &ret, &num_retvals, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteTensorHandle(ret);
+
+  // Clone an op with inputs and no device set.
+  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
+  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
+  TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_TensorHandle* inputs[] = {input1, input2};
+  TFE_OpAddInputList(original_identity, inputs, 2, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_Op* cloned_identity = CloneOp(original_identity);
+  EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status)));
+  TFE_TensorHandle* identity_ret[] = {nullptr, nullptr};
+  num_retvals = 2;
+  TFE_Execute(cloned_identity, identity_ret, &num_retvals, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TFE_DeleteTensorHandle(input1);
+  TFE_DeleteTensorHandle(input2);
+  TFE_DeleteTensorHandle(identity_ret[0]);
+  TFE_DeleteTensorHandle(identity_ret[1]);
+
+  TFE_DeleteOp(cloned_identity);
+  TFE_DeleteOp(original_identity);
+  TFE_DeleteOp(original_var_op);
+  TFE_DeleteOp(cloned);
+  TF_DeleteStatus(status);
+  TFE_DeleteContext(ctx);
+}
+
 }  // namespace
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index 192f105..fd68866 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -102,6 +102,32 @@
   return th;
 }
 
+TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
+                                                int64_t dims[], int num_dims) {
+  TF_Status* status = TF_NewStatus();
+  TF_Tensor* t =
+      TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
+  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+  return th;
+}
+
+TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
+                                              int64_t dims[], int num_dims) {
+  TF_Status* status = TF_NewStatus();
+  TF_Tensor* t =
+      TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
+  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+  return th;
+}
+
 TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
   constexpr int64_t dims[] = {100, 100};
   constexpr int num_elements = dims[0] * dims[1];
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
index fcf407a..2f77ae5 100644
--- a/tensorflow/c/eager/c_api_test_util.h
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -40,6 +40,14 @@
                                                   float data[], int64_t dims[],
                                                   int num_dims);
 
+// Get a TensorHandle with the given float values and dimensions
+TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
+                                                int64_t dims[], int num_dims);
+
+// Get a TensorHandle with the given int values and dimensions
+TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
+                                              int64_t dims[], int num_dims);
+
 // Return a tensor handle containing a 100x100 matrix of floats
 TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
 
diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc
index 1c078d4..b058c79 100644
--- a/tensorflow/c/eager/custom_device_test.cc
+++ b/tensorflow/c/eager/custom_device_test.cc
@@ -36,7 +36,8 @@
   bool arrived = false;
   bool executed = false;
   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
+  RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
+                        &arrived, &executed, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
   TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
   ASSERT_FALSE(arrived);
@@ -73,7 +74,8 @@
   bool executed = false;
   const char* custom_device_name =
       "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
+  RegisterLoggingDevice(context.get(), custom_device_name,
+                        /*strict_scope_placement=*/true, &arrived, &executed,
                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
@@ -103,7 +105,8 @@
   bool arrived = false;
   bool executed = false;
   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
+  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
+                        &arrived, &executed, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a variable handle placed on the custom device.
@@ -187,7 +190,8 @@
   bool arrived = false;
   bool executed = false;
   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
+  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
+                        &arrived, &executed, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
   // Create a variable handle placed on the custom device.
@@ -264,10 +268,12 @@
   const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
   bool arrived = false;
   bool executed = false;
-  RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
+  RegisterLoggingDevice(context.get(), custom0,
+                        /*strict_scope_placement=*/false, &arrived, &executed,
                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
+  RegisterLoggingDevice(context.get(), custom1,
+                        /*strict_scope_placement=*/true, &arrived, &executed,
                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
 
@@ -314,14 +320,34 @@
   ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
   ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
 
-  // Custom device: mix of custom/physical fails.
+  // Custom device: mix of custom/physical places the op on the custom device.
   matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
   num_retvals = 1;
+  executed = false;
   TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
-  ASSERT_NE(TF_OK, TF_GetCode(status.get()));
-  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
-  ASSERT_TRUE(
-      absl::StrContains(TF_Message(status.get()), "[]"));  // kVariantDeviceNull
+  EXPECT_TRUE(executed);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  TFE_DeleteTensorHandle(retval);
+
+  // Explicit placement still forces the op onto the requested device
+  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
+  TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
+                  status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  num_retvals = 1;
+  executed = false;
+  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
+  EXPECT_FALSE(executed);
+  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
+
+  // Custom devices can refuse to do type-based dispatch (as hcustom1 is
+  // configured to do)
+  matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
+  num_retvals = 1;
+  executed = false;
+  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
+  EXPECT_FALSE(executed);
+  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
 }
 
 TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
@@ -334,21 +360,24 @@
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
   bool arrived = false;
   bool executed = false;
-  RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
+  RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
+                        /*strict_scope_placement=*/true, &arrived, &executed,
                         status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
       << TF_Message(status.get());
 
   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
-  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
+  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
+                        &arrived, &executed, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-  RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
+  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
+                        &arrived, &executed, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
       << TF_Message(status.get());
 
-  RegisterLoggingDevice(context.get(),
-                        "/job:localhost/replica:0/task:0/device:CPU:0",
-                        &arrived, &executed, status.get());
+  RegisterLoggingDevice(
+      context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
+      /*strict_scope_placement=*/true, &arrived, &executed, status.get());
   ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
       << TF_Message(status.get());
 }
diff --git a/tensorflow/c/eager/custom_device_testutil.cc b/tensorflow/c/eager/custom_device_testutil.cc
index 28de366..014abe3 100644
--- a/tensorflow/c/eager/custom_device_testutil.cc
+++ b/tensorflow/c/eager/custom_device_testutil.cc
@@ -33,6 +33,9 @@
   bool* arrived_flag;
   // Set to true whenever an operation is executed
   bool* executed_flag;
+  // If true, only explicit op placements are accepted. If false, uses
+  // type-based dispatch.
+  bool strict_scope_placement;
 };
 
 struct LoggedTensor {
@@ -84,18 +87,35 @@
   return nullptr;
 }
 
-void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
-                          TFE_TensorHandle** inputs, const char* operation_name,
-                          const TFE_OpAttrs* attributes, int* num_outputs,
+void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                           TFE_TensorHandle** outputs, TF_Status* s,
                           void* device_info) {
+  const char* requested_placement = TFE_OpGetDevice(original_op, s);
+  if (TF_GetCode(s) != TF_OK) return;
+
   LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
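+  // An empty requested placement means the op was routed here by type-based
+  // dispatch (one of its inputs lives on this device) rather than by an
+  // explicit placement, which strict devices reject.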
+  if (dev->strict_scope_placement && *requested_placement == '\0') {
+    TF_SetStatus(s, TF_INTERNAL,
+                 "Ops must be placed on the device explicitly, or their inputs "
+                 "first copied to other devices.");
+    return;
+  }
+  TFE_Context* context = TFE_OpGetContext(original_op, s);
+  if (TF_GetCode(s) != TF_OK) return;
+  const char* operation_name = TFE_OpGetName(original_op, s);
+  if (TF_GetCode(s) != TF_OK) return;
+  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
+
   TFE_Op* op(TFE_NewOp(context, operation_name, s));
   if (TF_GetCode(s) != TF_OK) return;
   TFE_OpAddAttrs(op, attributes);
   TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
+  if (TF_GetCode(s) != TF_OK) return;
+  int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
+  if (TF_GetCode(s) != TF_OK) return;
   for (int j = 0; j < num_inputs; ++j) {
-    TFE_TensorHandle* input = inputs[j];
+    TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
+    if (TF_GetCode(s) != TF_OK) return;
     const char* input_device = TFE_TensorHandleDeviceName(input, s);
     if (TF_GetCode(s) != TF_OK) return;
     if (dev->device_name == input_device) {
@@ -131,8 +151,8 @@
 }  // namespace
 
 void RegisterLoggingDevice(TFE_Context* context, const char* name,
-                           bool* arrived_flag, bool* executed_flag,
-                           TF_Status* status) {
+                           bool strict_scope_placement, bool* arrived_flag,
+                           bool* executed_flag, TF_Status* status) {
   TFE_CustomDevice custom_device;
   custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
   custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
@@ -143,6 +163,7 @@
   device->executed_flag = executed_flag;
   device->device_name = name;
   device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+  device->strict_scope_placement = strict_scope_placement;
   TFE_RegisterCustomDevice(context, custom_device, name, device, status);
 }
 
@@ -168,5 +189,6 @@
   logging_device->device_name = name;
   logging_device->underlying_device =
       "/job:localhost/replica:0/task:0/device:CPU:0";
+  logging_device->strict_scope_placement = true;
   *device_info = reinterpret_cast<void*>(logging_device);
 }
diff --git a/tensorflow/c/eager/custom_device_testutil.h b/tensorflow/c/eager/custom_device_testutil.h
index 509df7d..a7c6008 100644
--- a/tensorflow/c/eager/custom_device_testutil.h
+++ b/tensorflow/c/eager/custom_device_testutil.h
@@ -25,8 +25,8 @@
 #include "tensorflow/c/tf_status.h"
 
 void RegisterLoggingDevice(TFE_Context* context, const char* name,
-                           bool* arrived_flag, bool* executed_flag,
-                           TF_Status* status);
+                           bool strict_scope_placement, bool* arrived_flag,
+                           bool* executed_flag, TF_Status* status);
 void AllocateLoggingDevice(const char* name, bool* arrived_flag,
                            bool* executed_flag, TFE_CustomDevice** device,
                            void** device_info);
diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc
index 9bcd0d0..89ff140 100644
--- a/tensorflow/c/eager/gradients.cc
+++ b/tensorflow/c/eager/gradients.cc
@@ -242,6 +242,7 @@
 Status Reset(AbstractOperation* op_, const char* op,
              const char* raw_device_name, ForwardOperation* forward_op_) {
   forward_op_->op_name = op;
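+  // Reset the attr builder to the new op name as well, so attributes recorded
+  // on forward_op_ stay consistent with the op being built.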
+  forward_op_->attrs.Reset(op);
   return op_->Reset(op, raw_device_name);
 }
 Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
@@ -418,6 +419,11 @@
     // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
     forward_op_->outputs.push_back(retvals[i]);
   }
+  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
+  // attributes. Number type attrs and DataType attrs work fine without this.
+  // Consider getting rid of this and making the behavior between number types
+  // and string consistent.
+  forward_op_->attrs.BuildNodeDef();
   std::vector<TapeTensor> tape_tensors;
   for (auto t : retvals) {
     tape_tensors.push_back(TapeTensor(t, ctx));
diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc
index 80b1f15..56f0b84 100644
--- a/tensorflow/c/eager/gradients_test.cc
+++ b/tensorflow/c/eager/gradients_test.cc
@@ -507,6 +507,57 @@
   result_tensor = nullptr;
 }
 
+TEST_P(CppGradients, TestSetAttrString) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  AbstractTensorHandlePtr t;
+  {
+    AbstractTensorHandle* x_raw = nullptr;
+    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    t.reset(x_raw);
+  }
+
+  AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
+  ForwardOperation forward_op;
+  forward_op.ctx = ctx.get();
+  Status s = Reset(check_numerics_op.get(), "CheckNumerics",
+                   /*raw_device_name=*/nullptr, &forward_op);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  if (isa<TracingOperation>(check_numerics_op.get())) {
+    s = dyn_cast<TracingOperation>(check_numerics_op.get())
+            ->SetOpName("check_numerics");
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  }
+  s = AddInput(check_numerics_op.get(), t.get(), &forward_op);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  string message = "This is the way!";
+  s = SetAttrString(check_numerics_op.get(), "message", message.data(),
+                    message.length(), &forward_op);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  int num_retvals = 1;
+  std::vector<AbstractTensorHandle*> outputs(1);
+  GradientRegistry registry;
+  std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
+  s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
+              &num_retvals, &forward_op, tape.get(), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
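+  // Reading the attribute back below relies on Execute() calling
+  // AttrBuilder::BuildNodeDef(); without it, attrs.Get() does not see string
+  // attributes (see the TODO in gradients.cc).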
+  string read_message;
+  s = forward_op.attrs.Get("message", &read_message);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  ASSERT_EQ(read_message, message);
+
+  outputs[0]->Unref();
+}
+
 // TODO(b/164171226): Enable this test with tfrt after AddInputList is
 // supported. It is needed for IdentityN.
 #ifdef PLATFORM_GOOGLE
diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h
index ee212b2..7b68ec2 100644
--- a/tensorflow/c/eager/immediate_execution_operation.h
+++ b/tensorflow/c/eager/immediate_execution_operation.h
@@ -47,9 +47,6 @@
   virtual Status InputLength(const char* input_name, int* length) = 0;
   virtual Status OutputLength(const char* output_name, int* length) = 0;
 
-  // Experimental
-  virtual Status SetUseXla(bool enable) = 0;
-
   // Set stack trace to be used for potential async error reporting.
   virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
 
diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc
new file mode 100644
index 0000000..1f04e25
--- /dev/null
+++ b/tensorflow/c/eager/mnist_gradients_test.cc
@@ -0,0 +1,781 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/c/eager/c_api_unified_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
+#include "tensorflow/c/eager/gradients.h"
+#include "tensorflow/c/eager/gradients_internal.h"
+#include "tensorflow/c/eager/mnist_gradients_testutil.h"
+#include "tensorflow/c/experimental/gradients/math_grad.h"
+#include "tensorflow/c/experimental/gradients/nn_grad.h"
+#include "tensorflow/c/experimental/ops/array_ops.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace gradients {
+namespace internal {
+namespace {
+
+class CppGradients
+    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
+ protected:
+  void SetUp() override {
+    TF_SetTracingImplementation(std::get<0>(GetParam()));
+  }
+};
+
+Status RegisterGradients(GradientRegistry* registry) {
+  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
+  TF_RETURN_IF_ERROR(
+      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
+                         SparseSoftmaxCrossEntropyLossRegisterer));
+  return Status::OK();
+}
+
+// ========================= Test Util Functions ==============================
+
+// Get a scalar TensorHandle with given value
+Status TestScalarTensorHandle(AbstractContext* ctx, float value,
+                              AbstractTensorHandle** tensor) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_Context* eager_ctx =
+      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
+  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
+  *tensor =
+      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
+  return Status::OK();
+}
+
+// Get a Matrix TensorHandle with given float values and dimensions
+Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
+                                     int64_t dims[], int num_dims,
+                                     AbstractTensorHandle** tensor) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_Context* eager_ctx =
+      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
+  TFE_TensorHandle* input_eager =
+      TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
+  *tensor =
+      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
+  return Status::OK();
+}
+
+// Get a Matrix TensorHandle with given int values and dimensions
+Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
+                                   int64_t dims[], int num_dims,
+                                   AbstractTensorHandle** tensor) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_Context* eager_ctx =
+      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
+  TFE_TensorHandle* input_eager =
+      TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
+  *tensor =
+      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
+  return Status::OK();
+}
+
+Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_TensorHandle* result_t =
+      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
+  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
+  return Status::OK();
+}
+
+AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
+                                                 float vals[], int64_t dims[],
+                                                 int num_dims) {
+  AbstractTensorHandlePtr A;
+  AbstractTensorHandle* a_raw = nullptr;
+  Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
+  A.reset(a_raw);
+  return A;
+}
+
+AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
+                                               int64_t dims[], int num_dims) {
+  AbstractTensorHandlePtr A;
+  AbstractTensorHandle* a_raw = nullptr;
+  Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
+  A.reset(a_raw);
+  return A;
+}
+
+// =========================== Start Tests ================================
+
+TEST_P(CppGradients, TestMatMulGrad) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
+  int64_t A_dims[] = {2, 2};
+  float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
+  int64_t B_dims[] = {2, 2};
+  int num_dims = 2;
+
+  AbstractTensorHandlePtr A =
+      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
+  AbstractTensorHandlePtr B =
+      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
+
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  /* Pseudo-code:
+   *
+   * tape.watch(A)
+   * tape.watch(B)
+   * Y = AB
+   * outputs = tape.gradient(Y, [A, B])
+   */
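+  // For reference, worked out by hand (not part of the recorded outputs):
+  // with upstream gradient dY = ones(2, 2), dA = dY * B^T = [[-.5, 2], [-.5, 2]]
+  // and dB = A^T * dY = [[4, 4], [6, 6]], matching the expected values below.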
+
+  std::vector<AbstractTensorHandle*> outputs(2);
+  s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
+               absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* dA_tensor;
+  s = GetValue(outputs[0], &dA_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data[4] = {0};
+  memcpy(&result_data[0], TF_TensorData(dA_tensor),
+         TF_TensorByteSize(dA_tensor));
+
+  float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
+  float tolerance = 1e-3;
+  for (int j = 0; j < 4; j++) {
+    ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
+  }
+
+  TF_Tensor* dB_tensor;
+  s = GetValue(outputs[1], &dB_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  memcpy(&result_data[0], TF_TensorData(dB_tensor),
+         TF_TensorByteSize(dB_tensor));
+
+  float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
+  for (int j = 0; j < 4; j++) {
+    ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
+  }
+
+  outputs[0]->Unref();
+  outputs[1]->Unref();
+  TF_DeleteTensor(dA_tensor);
+  TF_DeleteTensor(dB_tensor);
+}
+
+TEST_P(CppGradients, TestMNISTForward) {
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  // X = data
+  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
+  int64_t dims[] = {2, 2};
+  int num_dims = 2;
+  AbstractTensorHandlePtr X =
+      GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims);
+
+  // W1 = first weights
+  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
+  AbstractTensorHandlePtr W1 =
+      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
+
+  // W2 = second weights
+  float W2_vals[] = {.1f, .2f, .3f, -.5f};
+  AbstractTensorHandlePtr W2 =
+      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
+
+  // y = labels
+  int y_vals[] = {1, 1};
+  int64_t dims_y[] = {2};
+  num_dims = sizeof(dims_y) / sizeof(dims_y[0]);
+  AbstractTensorHandlePtr y =
+      GetTensorHandleUtilInt(ctx.get(), y_vals, dims_y, num_dims);
+
+  GradientRegistry registry;
+
+  // Run the Forward Pass
+  std::vector<AbstractTensorHandle*> outputs(2);
+  Status s =
+      RunModel(MNISTForwardModel, ctx.get(),
+               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Verify the Results
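+  // For reference, derived by hand: X*W1 = [[0, 12], [-1, 34]], so
+  // Relu(X*W1) = [[0, 12], [0, 34]] and scores = Relu(X*W1)*W2 =
+  // [[3.6, -6.0], [10.2, -17.0]]; with label 1 for each row, the losses are
+  // approximately [9.6, 27.2].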
+  TF_Tensor* scores_tensor;
+  s = GetValue(outputs[0], &scores_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data[4] = {0};
+  memcpy(&result_data[0], TF_TensorData(scores_tensor),
+         TF_TensorByteSize(scores_tensor));
+
+  float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f};
+  float tolerance = 1e-3;
+  for (int j = 0; j < 4; j++) {
+    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
+  }
+
+  TF_Tensor* loss_vals_tensor;
+  s = GetValue(outputs[1], &loss_vals_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
+         TF_TensorByteSize(loss_vals_tensor));
+  float expected_losses[2] = {9.6f, 27.2f};
+  for (int j = 0; j < 2; j++) {
+    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
+  }
+
+  outputs[0]->Unref();
+  outputs[1]->Unref();
+  TF_DeleteTensor(scores_tensor);
+  TF_DeleteTensor(loss_vals_tensor);
+}
+
+TEST_P(CppGradients, TestMNISTForward2) {
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  // X = data
+  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  int64_t X_dims[] = {3, 2};
+  int num_dims = 2;
+  AbstractTensorHandlePtr X =
+      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
+
+  // W1 = first weights
+  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
+  int64_t dims[] = {2, 2};
+  AbstractTensorHandlePtr W1 =
+      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
+
+  // W2 = second weights
+  float W2_vals[] = {.1f, .2f, .3f, -.5f};
+  AbstractTensorHandlePtr W2 =
+      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
+
+  // y = labels
+  int y_vals[] = {1, 1, 1};
+  int64_t y_dims[] = {3};
+  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
+  AbstractTensorHandlePtr y =
+      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
+
+  GradientRegistry registry;
+
+  // Run the Forward Pass
+  std::vector<AbstractTensorHandle*> outputs(2);
+  Status s =
+      RunModel(MNISTForwardModel, ctx.get(),
+               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Verify the Results
+  TF_Tensor* scores_tensor;
+  s = GetValue(outputs[0], &scores_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data[6] = {0};
+  memcpy(&result_data[0], TF_TensorData(scores_tensor),
+         TF_TensorByteSize(scores_tensor));
+
+  float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
+  float tolerance = 1e-3;
+  for (int j = 0; j < 6; j++) {
+    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
+  }
+
+  TF_Tensor* loss_vals_tensor;
+  s = GetValue(outputs[1], &loss_vals_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
+         TF_TensorByteSize(loss_vals_tensor));
+  float expected_losses[3] = {9.6f, 27.2f, 44.8f};
+  for (int j = 0; j < 3; j++) {
+    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
+  }
+
+  outputs[0]->Unref();
+  outputs[1]->Unref();
+  TF_DeleteTensor(scores_tensor);
+  TF_DeleteTensor(loss_vals_tensor);
+}
+
+TEST_P(CppGradients, TestMatMulTranspose) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  // X = data
+  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  int64_t X_dims[] = {2, 3};
+  int num_dims = 2;
+  AbstractTensorHandlePtr X =
+      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
+
+  // W1 = first weights
+  float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
+  int64_t dims[] = {2, 2};
+  AbstractTensorHandlePtr W1 =
+      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
+
+  GradientRegistry registry;
+
+  // Run the MatMul Op
+  std::vector<AbstractTensorHandle*> outputs(1);
+
+  Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
+                      absl::MakeSpan(outputs),
+                      /*use_function=*/!std::get<2>(GetParam()), registry);
+
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Verify the Results
+  TF_Tensor* scores_tensor;
+  s = GetValue(outputs[0], &scores_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data[6] = {0};
+  memcpy(&result_data[0], TF_TensorData(scores_tensor),
+         TF_TensorByteSize(scores_tensor));
+
+  float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
+  float tolerance = 1e-3;
+  for (int j = 0; j < 6; j++) {
+    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
+  }
+
+  outputs[0]->Unref();
+  TF_DeleteTensor(scores_tensor);
+}
+
+TEST_P(CppGradients, TestReluGrad) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  // X = data
+  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
+  int64_t X_dims[] = {3, 3};
+  int num_dims = 2;
+  AbstractTensorHandlePtr X =
+      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
+
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  /* Pseudo-code:
+   *
+   * tape.watch(X)
+   * Y = Relu(X)
+   * outputs = tape.gradient(Y, [X])
+   */
+  std::vector<AbstractTensorHandle*> outputs(1);
+  s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* dX_tensor;
+  s = GetValue(outputs[0], &dX_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data[9] = {0};
+  memcpy(&result_data[0], TF_TensorData(dX_tensor),
+         TF_TensorByteSize(dX_tensor));
+
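+  // Relu passes through a gradient of 1 where X > 0 and 0 elsewhere
+  // (including at X == 0):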
+  float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
+  float tolerance = 1e-3;
+  for (int j = 0; j < 9; j++) {
+    ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
+  }
+
+  outputs[0]->Unref();
+  TF_DeleteTensor(dX_tensor);
+}
+
+TEST_P(CppGradients, TestSoftmaxLossGrad) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  // X = scores
+  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
+  int64_t X_dims[] = {3, 3};
+  int num_dims = 2;
+  AbstractTensorHandlePtr X =
+      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
+
+  // y = labels
+  int y_vals[] = {1, 0, 1};
+  int64_t y_dims[] = {3};
+  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
+  AbstractTensorHandlePtr y =
+      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
+
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  /* Pseudo-code:
+   *
+   * tape.watch(X)
+   * tape.watch(labels)
+   * loss = SoftmaxLoss(X, labels)
+   * outputs = tape.gradient(loss, [X, labels])
+   */
+
+  std::vector<AbstractTensorHandle*> outputs(2);
+  s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()},
+               absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* dX_tensor;
+  s = GetValue(outputs[0], &dX_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data[9] = {0};
+  memcpy(&result_data[0], TF_TensorData(dX_tensor),
+         TF_TensorByteSize(dX_tensor));
+
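+  // The gradient of sparse softmax cross-entropy with respect to the logits
+  // is softmax(X) - one_hot(y), computed row by row: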
+  float expected_dX[9] = {0.090f,  -0.7553f, 0.6652f,  -0.9099f, 0.2447f,
+                          0.6652f, 0.8437f,  -0.8858f, 0.0420f};
+  float tolerance = 1e-3;
+  for (int j = 0; j < 9; j++) {
+    ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
+  }
+
+  // Only Unref() the first output; the second is the nullptr gradient for the
+  // labels.
+  outputs[0]->Unref();
+  TF_DeleteTensor(dX_tensor);
+}
+
+TEST_P(CppGradients, TestMNISTGrad) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  // X = data
+  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
+  int64_t X_dims[] = {2, 2};
+  int num_dims = 2;
+  AbstractTensorHandlePtr X =
+      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
+
+  // W1 = first weights
+  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
+  int64_t dims[] = {2, 2};
+  AbstractTensorHandlePtr W1 =
+      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
+
+  // W2 = second weights
+  float W2_vals[] = {.1f, .2f, .3f, -.5f};
+  AbstractTensorHandlePtr W2 =
+      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
+
+  // y = labels
+  int y_vals[] = {1, 1};
+  int64_t y_dims[] = {2};
+  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
+  AbstractTensorHandlePtr y =
+      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
+
+  // Register Grads
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  /* Pseudo-code:
+   *
+   * tape.watch(W1)
+   * tape.watch(W2)
+   * mm = X*W1
+   * hidden = Relu(mm)
+   * scores = hidden*W2
+   * loss = SoftmaxLoss(scores, y)
+   * outputs = tape.gradient(loss, [W1, W2])
+   */
+
+  std::vector<AbstractTensorHandle*> outputs(3);
+  s = RunModel(MNISTGradModel, ctx.get(),
+               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
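+  // For reference, hand-derived: dscores = softmax(scores) - one_hot(y) is
+  // approximately [[1, -1], [1, -1]], giving dW2 = hidden^T * dscores =
+  // [[0, 0], [46, -46]] and dW1 = X^T * (dscores * W2^T, masked by
+  // relu'(X*W1)) = [[0, 3.2], [0, 4.8]].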
+  float tolerance = 1e-3;
+  TF_Tensor* dW1_tensor;
+  s = GetValue(outputs[0], &dW1_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data[4] = {0};
+  memcpy(&result_data[0], TF_TensorData(dW1_tensor),
+         TF_TensorByteSize(dW1_tensor));
+
+  float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};  // dLoss/dW1
+  for (int j = 0; j < 4; j++) {
+    ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
+  }
+
+  TF_Tensor* dW2_tensor;
+  s = GetValue(outputs[1], &dW2_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  memcpy(&result_data[0], TF_TensorData(dW2_tensor),
+         TF_TensorByteSize(dW2_tensor));
+
+  float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f};  // dLoss/dW2
+  for (int j = 0; j < 4; j++) {
+    ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
+  }
+
+  outputs[0]->Unref();
+  outputs[1]->Unref();
+  outputs[2]->Unref();
+  TF_DeleteTensor(dW1_tensor);
+  TF_DeleteTensor(dW2_tensor);
+}
+
+TEST_P(CppGradients, TestScalarMul) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  AbstractTensorHandlePtr eta;
+  {
+    AbstractTensorHandle* x_raw = nullptr;
+    Status s = TestScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    eta.reset(x_raw);
+  }
+
+  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
+  int64_t A_dims[] = {2, 2};
+  int num_dims = 2;
+
+  AbstractTensorHandlePtr A =
+      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
+
+  GradientRegistry registry;
+  std::vector<AbstractTensorHandle*> outputs(1);
+  Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()},
+                      absl::MakeSpan(outputs),
+                      /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* dA_tensor;
+  s = GetValue(outputs[0], &dA_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data[4] = {0};
+  memcpy(&result_data[0], TF_TensorData(dA_tensor),
+         TF_TensorByteSize(dA_tensor));
+
+  float tolerance = 1e-3;
+  float eta_val = 1.5f;
+  for (int j = 0; j < 4; j++) {
+    ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
+  }
+
+  outputs[0]->Unref();
+  TF_DeleteTensor(dA_tensor);
+}
+
+TEST_P(CppGradients, TestMNIST_Training) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  // X = data
+  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
+  int64_t X_dims[] = {2, 2};
+  int num_dims = 2;
+  AbstractTensorHandlePtr X =
+      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
+
+  // TODO(amturati): use random initializer for weights instead of
+  // constant values.
+
+  // W1 = first weights
+  float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f};
+  int64_t dims[] = {2, 2};
+  AbstractTensorHandlePtr W1 =
+      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
+
+  // W2 = second weights
+  float W2_vals[] = {.1f, .2f, .3f, -.5f};
+  AbstractTensorHandlePtr W2 =
+      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
+
+  // y = labels
+  int y_vals[] = {1, 1};
+  int64_t y_dims[] = {2};
+  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
+  AbstractTensorHandlePtr y =
+      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
+
+  // Register Grads
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Prepare for training
+  std::vector<AbstractTensorHandle*> weights;
+  weights.push_back(W1.get());
+  weights.push_back(W2.get());
+
+  // Set learning rate to be 1e-1
+  AbstractTensorHandle* learning_rate = nullptr;
+  s = TestScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Train
+  int num_iters = 10;
+  std::vector<AbstractTensorHandle*> mnist_outputs(3);
+  std::vector<AbstractTensorHandle*> grads(2);
+  for (int i = 0; i < num_iters; i++) {
+    // Run Forward Pass
+    s = RunModel(MNISTGradModel, ctx.get(),
+                 {X.get(), weights[0], weights[1], y.get()},
+                 absl::MakeSpan(mnist_outputs),
+                 /*use_function=*/!std::get<2>(GetParam()), registry);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+    // Fill grads
+    grads[0] = mnist_outputs[0];
+    grads[1] = mnist_outputs[1];
+
+    // Gradient Update
+    s = UpdateWeights(ctx.get(), grads, weights, learning_rate);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  }
+
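+  // Only the handles produced by the final iteration are released below; this
+  // test does not track intermediate handles from earlier iterations.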
+  grads[0]->Unref();          // release W1_grad
+  grads[1]->Unref();          // release W2_grad
+  mnist_outputs[2]->Unref();  // release loss
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    UnifiedCAPI, CppGradients,
+    ::testing::Combine(::testing::Values("graphdef", "mlir"),
+                       /*tfrt*/ ::testing::Values(false),
+                       /*executing_eagerly*/ ::testing::Values(true, false)));
+}  // namespace
+}  // namespace internal
+}  // namespace gradients
+}  // namespace tensorflow
diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc
new file mode 100644
index 0000000..9f5d0d1
--- /dev/null
+++ b/tensorflow/c/eager/mnist_gradients_testutil.cc
@@ -0,0 +1,603 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/eager/mnist_gradients_testutil.h"
+
+#include <memory>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
+#include "tensorflow/c/eager/gradients.h"
+#include "tensorflow/c/eager/gradients_internal.h"
+#include "tensorflow/c/experimental/ops/array_ops.h"
+#include "tensorflow/c/experimental/ops/math_ops.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+
+// ========================== Tape Ops ==============================
+
+namespace tensorflow {
+namespace gradients {
+namespace internal {
+
+using std::vector;
+using tensorflow::tracing::TracingOperation;
+
+// Computes `inputs[0] + inputs[1]` and records it on the tape.
+Status Add(AbstractContext* ctx, Tape* tape,
+           absl::Span<AbstractTensorHandle* const> inputs,
+           absl::Span<AbstractTensorHandle*> outputs,
+           const GradientRegistry& registry) {
+  AbstractOperationPtr add_op(ctx->CreateOperation());
+  ForwardOperation forward_op;
+  forward_op.ctx = ctx;
+  TF_RETURN_IF_ERROR(
+      Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
+  if (isa<TracingOperation>(add_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
+  }
+  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
+  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
+  int num_retvals = 1;
+  return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
+                 registry);
+}
+
+// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
+Status MatMul(AbstractContext* ctx, Tape* tape,
+              absl::Span<AbstractTensorHandle* const> inputs,
+              absl::Span<AbstractTensorHandle*> outputs, const char* name,
+              bool transpose_a, bool transpose_b,
+              const GradientRegistry& registry) {
+  AbstractOperationPtr matmul_op(ctx->CreateOperation());
+  ForwardOperation forward_op;
+  forward_op.ctx = ctx;
+  TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
+                           /*raw_device_name=*/nullptr, &forward_op));
+  if (isa<TracingOperation>(matmul_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
+  }
+
+  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
+  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
+  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
+      matmul_op.get(), "transpose_a", transpose_a, &forward_op));
+  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
+      matmul_op.get(), "transpose_b", transpose_b, &forward_op));
+
+  int num_retvals = 1;
+  return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
+                 registry);
+}
+
+Status Mul(AbstractContext* ctx, Tape* tape,
+           absl::Span<AbstractTensorHandle* const> inputs,
+           absl::Span<AbstractTensorHandle*> outputs, const char* name,
+           const GradientRegistry& registry) {
+  AbstractOperationPtr mul_op(ctx->CreateOperation());
+  ForwardOperation forward_op;
+  forward_op.ctx = ctx;
+  TF_RETURN_IF_ERROR(
+      Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
+  if (isa<TracingOperation>(mul_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
+  }
+
+  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
+  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
+
+  int num_retvals = 1;
+  return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
+                 registry);
+}
+
+// Computes `Relu(inputs[0])` and records it on the tape.
+Status Relu(AbstractContext* ctx, Tape* tape,
+            absl::Span<AbstractTensorHandle* const> inputs,
+            absl::Span<AbstractTensorHandle*> outputs, const char* name,
+            const GradientRegistry& registry) {
+  AbstractOperationPtr relu_op(ctx->CreateOperation());
+  ForwardOperation forward_op;
+  forward_op.ctx = ctx;
+  TF_RETURN_IF_ERROR(
+      Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
+  if (isa<TracingOperation>(relu_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
+  }
+  TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
+  int num_retvals = 1;
+  return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
+                 registry);
+}
+
+// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
+// tape.
+Status SparseSoftmaxCrossEntropyLoss(
+    AbstractContext* ctx, Tape* tape,
+    absl::Span<AbstractTensorHandle* const> inputs,
+    absl::Span<AbstractTensorHandle*> outputs, const char* name,
+    const GradientRegistry& registry) {
+  AbstractTensorHandle* scores = inputs[0];
+  AbstractTensorHandle* labels = inputs[1];
+
+  AbstractOperationPtr sm_op(ctx->CreateOperation());
+  ForwardOperation forward_op;
+  forward_op.ctx = ctx;
+  TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
+                           /*raw_device_name=*/nullptr, &forward_op));
+  if (isa<TracingOperation>(sm_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
+  }
+
+  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
+  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
+
+  int num_retvals = 2;  // returns loss values and backprop
+  return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
+                 registry);
+}
+
+//===================== Test Models to run =========================
+
+// Computes
+// y = inputs[0] + inputs[1]
+// return grad(y, {inputs[0], inputs[1]})
+Status AddGradModel(AbstractContext* ctx,
+                    absl::Span<AbstractTensorHandle* const> inputs,
+                    absl::Span<AbstractTensorHandle*> outputs,
+                    const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch x.
+  tape->Watch(ToId(inputs[1]));  // Watch y.
+  std::vector<AbstractTensorHandle*> add_outputs(1);
+  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
+                         registry));  // Compute x+y.
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
+  std::vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
+      source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+  for (auto add_output : add_outputs) {
+    add_output->Unref();
+  }
+  outputs[0] = out_grads[0];
+  outputs[1] = out_grads[1];
+  delete tape;
+  return Status::OK();
+}
+
+// Computes
+// y = inputs[0] * inputs[1]
+// return grad(y, {inputs[0], inputs[1]})
+Status MatMulGradModel(AbstractContext* ctx,
+                       absl::Span<AbstractTensorHandle* const> inputs,
+                       absl::Span<AbstractTensorHandle*> outputs,
+                       const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch x.
+  tape->Watch(ToId(inputs[1]));  // Watch y.
+  vector<AbstractTensorHandle*> mm_outputs(1);
+  TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
+                            "matmul0", /*transpose_a=*/false,
+                            /*transpose_b=*/false, registry));  // Compute x*y.
+
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
+  vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
+      source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+  for (auto mm_output : mm_outputs) {
+    mm_output->Unref();
+  }
+  outputs[0] = out_grads[0];
+  outputs[1] = out_grads[1];
+  delete tape;
+  return Status::OK();
+}
+
+// Model to run 2-layer net
+Status MNISTForwardModel(AbstractContext* ctx,
+                         absl::Span<AbstractTensorHandle* const> inputs,
+                         absl::Span<AbstractTensorHandle*> outputs,
+                         const GradientRegistry& registry) {
+  /**
+   * We will trace a 2-layer fully connected network for an MNIST model:
+   *
+   *   def mnist_forward(X, W1, W2, y_labels):
+   *     mm_out_1 = tf.matmul(X,W1)
+   *     hidden_layer = tf.nn.relu(mm_out_1)
+   *     scores = tf.matmul(hidden_layer,W2)
+   *     softmax = tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
+   *                                                              y_labels)
+   *     return scores, softmax
+   *
+   * Use this convention for inputs:
+   *
+   *   inputs = [X, W1, W2, y_labels]
+   *
+   */
+  AbstractTensorHandle* X = inputs[0];
+  AbstractTensorHandle* W1 = inputs[1];
+  AbstractTensorHandle* W2 = inputs[2];
+  AbstractTensorHandle* y_labels = inputs[3];
+
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(W1));  // Watch W1.
+  tape->Watch(ToId(W2));  // Watch W2.
+  vector<AbstractTensorHandle*> temp_outputs(1);
+
+  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
+                            "matmul0", /*transpose_a=*/false,
+                            /*transpose_b=*/false, registry));  // Compute X*W1
+
+  TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
+                          absl::MakeSpan(temp_outputs), "relu",
+                          registry));  // Compute Relu(X*W1)
+
+  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
+                            absl::MakeSpan(temp_outputs), "matmul1",
+                            /*transpose_a=*/false, /*transpose_b=*/false,
+                            registry));  // Compute Relu(X*W1)*W2
+
+  AbstractTensorHandle* scores = temp_outputs[0];
+
+  temp_outputs.resize(2);
+  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
+      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmax_loss", registry));  // Compute Softmax(Scores,labels)
+
+  AbstractTensorHandle* loss_vals = temp_outputs[0];
+
+  outputs[0] = scores;
+  outputs[1] = loss_vals;
+  delete tape;
+  return Status::OK();
+}
+
+Status MatMulTransposeModel(AbstractContext* ctx,
+                            absl::Span<AbstractTensorHandle* const> inputs,
+                            absl::Span<AbstractTensorHandle*> outputs,
+                            const GradientRegistry& registry) {
+  AbstractTensorHandle* X = inputs[0];
+  AbstractTensorHandle* W1 = inputs[1];
+
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(X));
+  tape->Watch(ToId(W1));
+  vector<AbstractTensorHandle*> temp_outputs(1);
+
+  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
+                            "matmul0", /*transpose_a=*/true,
+                            /*transpose_b=*/false, registry));  // Compute X*W1
+
+  outputs[0] = temp_outputs[0];
+
+  delete tape;
+  return Status::OK();
+}
+
+Status ReluGradModel(AbstractContext* ctx,
+                     absl::Span<AbstractTensorHandle* const> inputs,
+                     absl::Span<AbstractTensorHandle*> outputs,
+                     const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch X
+  vector<AbstractTensorHandle*> relu_outputs(1);
+  TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
+                          "relu0", registry));  // Relu(X)
+
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
+  vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+
+  for (auto relu_output : relu_outputs) {
+    relu_output->Unref();
+  }
+
+  outputs[0] = out_grads[0];
+  delete tape;
+  return Status::OK();
+}
+
+Status SoftmaxLossGradModel(AbstractContext* ctx,
+                            absl::Span<AbstractTensorHandle* const> inputs,
+                            absl::Span<AbstractTensorHandle*> outputs,
+                            const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch scores.
+  tape->Watch(ToId(inputs[1]));  // Watch labels.
+  vector<AbstractTensorHandle*> sm_outputs(2);
+  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
+      ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
+
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
+  vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
+      source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+
+  outputs[0] = out_grads[0];
+  outputs[1] = out_grads[1];
+  delete tape;
+  return Status::OK();
+}
+
+Status MNISTGradModel(AbstractContext* ctx,
+                      absl::Span<AbstractTensorHandle* const> inputs,
+                      absl::Span<AbstractTensorHandle*> outputs,
+                      const GradientRegistry& registry) {
+  AbstractTensorHandle* X = inputs[0];
+  AbstractTensorHandle* W1 = inputs[1];
+  AbstractTensorHandle* W2 = inputs[2];
+  AbstractTensorHandle* y_labels = inputs[3];
+
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/true);
+  tape->Watch(ToId(X));   // Watch X.
+  tape->Watch(ToId(W1));  // Watch W1.
+  tape->Watch(ToId(W2));  // Watch W2.
+  vector<AbstractTensorHandle*> temp_outputs(1);
+  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
+                            "matmul0", /*transpose_a=*/false,
+                            /*transpose_b=*/false, registry));  // Compute X*W1
+
+  AbstractTensorHandle* mm = temp_outputs[0];
+
+  TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
+                          absl::MakeSpan(temp_outputs),  // Relu(X*W1)
+                          "relu0", registry));
+
+  AbstractTensorHandle* hidden = temp_outputs[0];
+
+  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
+                            absl::MakeSpan(temp_outputs), "matmul1",
+                            /*transpose_a=*/false, /*transpose_b=*/false,
+                            registry));  // Compute Relu(X*W1)*W2
+
+  AbstractTensorHandle* scores = temp_outputs[0];
+
+  temp_outputs.resize(2);
+  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
+      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmaxloss", registry));  // W2*Relu(X*W1)
+
+  AbstractTensorHandle* loss = temp_outputs[0];
+
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
+  vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(
+      tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)},
+                            /*source_tensor_ids=*/{ToId(W1), ToId(W2)},
+                            source_tensors_that_are_targets,
+                            /*output_gradients=*/{}, &out_grads,
+                            /*build_default_zeros_grads=*/false));
+
+  // Only release 2nd temp output as first holds loss values.
+  temp_outputs[1]->Unref();
+
+  outputs[0] = out_grads[0];  // dW1
+  outputs[1] = out_grads[1];  // dW2
+  outputs[2] = loss;
+
+  delete tape;
+  return Status::OK();
+}
+
+Status ScalarMulModel(AbstractContext* ctx,
+                      absl::Span<AbstractTensorHandle* const> inputs,
+                      absl::Span<AbstractTensorHandle*> outputs,
+                      const GradientRegistry& registry) {
+  AbstractTensorHandle* eta = inputs[0];
+  AbstractTensorHandle* A = inputs[1];
+
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  vector<AbstractTensorHandle*> temp_outputs(1);
+
+  TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
+                         "scalarMul0", registry));  // Compute eta*A
+
+  outputs[0] = temp_outputs[0];
+
+  delete tape;
+  return Status::OK();
+}
+
+// ============================= End Models ================================
+
+Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
+                     vector<AbstractTensorHandle*>& weights,
+                     AbstractTensorHandle* learning_rate) {
+  /* Update weights one by one using gradient update rule:
+   *
+   *    w -= lr*grad[w]
+   *
+   *  NOTE: assuming learning rate is positive
+   */
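+  // For example, with lr = 0.1, w = 1.0 and grad[w] = 2.0 the rule yields
+  // w = 1.0 - 0.1 * 2.0 = 0.8.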
+
+  int num_grads = grads.size();
+  vector<AbstractTensorHandle*> temp_outputs(1);
+  std::string update_str;
+
+  // Negate learning rate for gradient descent
+  TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
+                              absl::MakeSpan(temp_outputs),
+                              "neg_lr"));  // Compute -lr
+  learning_rate = temp_outputs[0];
+
+  for (int i = 0; i < num_grads; i++) {
+    // Compute dW = -lr * grad(w[i])
+    update_str = "update_mul_" + std::to_string(i);
+    TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
+                                absl::MakeSpan(temp_outputs),
+                                update_str.c_str()));
+
+    AbstractTensorHandle* dW = temp_outputs[0];
+
+    // Compute temp = weights[i] + dW
+    update_str = "update_add_" + std::to_string(i);
+    TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
+                                absl::MakeSpan(temp_outputs),
+                                update_str.c_str()));
+
+    // Update the weights
+    weights[i] = temp_outputs[0];
+  }
+
+  return Status::OK();
+}
+
+AbstractContext* BuildFunction(const char* fn_name) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
+  return unwrap(graph_ctx);
+}
+
+Status CreateParamsForInputs(AbstractContext* ctx,
+                             absl::Span<AbstractTensorHandle* const> inputs,
+                             vector<AbstractTensorHandle*>* params) {
+  tracing::TracingTensorHandle* handle = nullptr;
+  for (auto input : inputs) {
+    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
+        input->DataType(), &handle));
+    params->emplace_back(handle);
+  }
+  return Status::OK();
+}
+
+Status RunModel(Model model, AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
+                const GradientRegistry& registry) {
+  if (use_function) {
+    const char* fn_name = "test_fn";
+    std::unique_ptr<AbstractFunction> scoped_func;
+    // Returning null tensors from a tf.function is not supported, so we keep
+    // track of which indices in the model's outputs are nullptr in this set.
+    // The FunctionDef only outputs the non-null tensors. We later pad the
+    // function op outputs with nullptrs at the `null_indices`.
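+    // For example, if the model writes {dW, nullptr, loss} into `outputs`,
+    // the traced function returns {dW, loss} and null_indices = {1}.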
+    absl::flat_hash_set<int> null_indices;
+    {
+      AbstractContextPtr func_ctx(BuildFunction(fn_name));
+      vector<AbstractTensorHandle*> func_inputs;
+      func_inputs.reserve(inputs.size());
+      TF_RETURN_IF_ERROR(
+          CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
+      vector<AbstractTensorHandle*> model_outputs;
+      model_outputs.resize(outputs.size());
+      TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
+                               absl::MakeSpan(model_outputs), registry));
+      for (auto func_input : func_inputs) {
+        func_input->Unref();
+      }
+      AbstractFunction* func = nullptr;
+      OutputList output_list;
+      output_list.expected_num_outputs = 0;
+      output_list.outputs.reserve(outputs.size());
+      for (int i = 0; i < model_outputs.size(); i++) {
+        if (model_outputs[i]) {
+          output_list.outputs.emplace_back(model_outputs[i]);
+          output_list.expected_num_outputs += 1;
+        } else {
+          null_indices.insert(i);
+        }
+      }
+      TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
+                             ->Finalize(&output_list, &func));
+      scoped_func.reset(func);
+      for (auto output : output_list.outputs) {
+        output->Unref();
+      }
+      TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
+    }
+
+    AbstractOperationPtr fn_op(ctx->CreateOperation());
+    TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
+    for (auto input : inputs) {
+      TF_RETURN_IF_ERROR(fn_op->AddInput(input));
+    }
+    int retvals = outputs.size() - null_indices.size();
+    vector<AbstractTensorHandle*> fn_outputs(retvals);
+    TF_RETURN_IF_ERROR(fn_op->Execute(
+        absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
+        &retvals));
+    int skipped_indices = 0;
+    for (int i = 0; i < outputs.size(); i++) {
+      if (!null_indices.contains(i)) {
+        outputs[i] = fn_outputs[i - skipped_indices];
+      } else {
+        skipped_indices += 1;
+      }
+    }
+    TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
+    return Status::OK();
+  } else {
+    return model(ctx, inputs, outputs, registry);
+  }
+}
+
+Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
+  *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
+  // Delete the options before checking the status so they are not leaked on
+  // the error path.
+  TFE_DeleteContextOptions(opts);
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
+  return Status::OK();
+}
+
+}  // namespace internal
+}  // namespace gradients
+}  // namespace tensorflow
diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h
new file mode 100644
index 0000000..efe196e
--- /dev/null
+++ b/tensorflow/c/eager/mnist_gradients_testutil.h
@@ -0,0 +1,150 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
+#include "tensorflow/c/eager/gradients.h"
+#include "tensorflow/c/eager/gradients_internal.h"
+#include "tensorflow/c/experimental/ops/array_ops.h"
+#include "tensorflow/c/experimental/ops/math_ops.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+#include "tensorflow/core/platform/status.h"
+
+// ========================== Tape Ops ==============================
+
+namespace tensorflow {
+namespace gradients {
+namespace internal {
+// Computes `inputs[0] + inputs[1]` and records it on the tape.
+Status Add(AbstractContext* ctx, Tape* tape,
+           absl::Span<AbstractTensorHandle* const> inputs,
+           absl::Span<AbstractTensorHandle*> outputs,
+           const GradientRegistry& registry);
+
+// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
+Status MatMul(AbstractContext* ctx, Tape* tape,
+              absl::Span<AbstractTensorHandle* const> inputs,
+              absl::Span<AbstractTensorHandle*> outputs, const char* name,
+              bool transpose_a, bool transpose_b,
+              const GradientRegistry& registry);
+
+// Computes `inputs[0] * inputs[1]` and records it on the tape.
+Status Mul(AbstractContext* ctx, Tape* tape,
+           absl::Span<AbstractTensorHandle* const> inputs,
+           absl::Span<AbstractTensorHandle*> outputs, const char* name,
+           const GradientRegistry& registry);
+
+// Computes `Relu(inputs[0])` and records it on the tape.
+Status Relu(AbstractContext* ctx, Tape* tape,
+            absl::Span<AbstractTensorHandle* const> inputs,
+            absl::Span<AbstractTensorHandle*> outputs, const char* name,
+            const GradientRegistry& registry);
+
+// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
+// tape.
+Status SparseSoftmaxCrossEntropyLoss(
+    AbstractContext* ctx, Tape* tape,
+    absl::Span<AbstractTensorHandle* const> inputs,
+    absl::Span<AbstractTensorHandle*> outputs, const char* name,
+    const GradientRegistry& registry);
+
+// ====================== End Tape Ops ============================
+
+// Computes
+//   y = inputs[0] + inputs[1]
+// and returns grad(y, {inputs[0], inputs[1]}).
+Status AddGradModel(AbstractContext* ctx,
+                    absl::Span<AbstractTensorHandle* const> inputs,
+                    absl::Span<AbstractTensorHandle*> outputs,
+                    const GradientRegistry& registry);
+
+// Computes
+//   y = inputs[0] * inputs[1]
+// and returns grad(y, {inputs[0], inputs[1]}).
+Status MatMulGradModel(AbstractContext* ctx,
+                       absl::Span<AbstractTensorHandle* const> inputs,
+                       absl::Span<AbstractTensorHandle*> outputs,
+                       const GradientRegistry& registry);
+
+// Computes a 2-layer neural network with softmax loss.
+Status MNISTForwardModel(AbstractContext* ctx,
+                         absl::Span<AbstractTensorHandle* const> inputs,
+                         absl::Span<AbstractTensorHandle*> outputs,
+                         const GradientRegistry& registry);
+
+// Computes MatMul with the first matrix transposed.
+Status MatMulTransposeModel(AbstractContext* ctx,
+                            absl::Span<AbstractTensorHandle* const> inputs,
+                            absl::Span<AbstractTensorHandle*> outputs,
+                            const GradientRegistry& registry);
+
+// Test model to verify ReluGrad functionality.
+Status ReluGradModel(AbstractContext* ctx,
+                     absl::Span<AbstractTensorHandle* const> inputs,
+                     absl::Span<AbstractTensorHandle*> outputs,
+                     const GradientRegistry& registry);
+
+// Test model to verify SoftmaxGrad functionality.
+Status SoftmaxLossGradModel(AbstractContext* ctx,
+                            absl::Span<AbstractTensorHandle* const> inputs,
+                            absl::Span<AbstractTensorHandle*> outputs,
+                            const GradientRegistry& registry);
+
+// Test model to verify multi-gradient functionality for MNIST.
+Status MNISTGradModel(AbstractContext* ctx,
+                      absl::Span<AbstractTensorHandle* const> inputs,
+                      absl::Span<AbstractTensorHandle*> outputs,
+                      const GradientRegistry& registry);
+
+// Test model to verify the scalar-tensor multiplication op.
+Status ScalarMulModel(AbstractContext* ctx,
+                      absl::Span<AbstractTensorHandle* const> inputs,
+                      absl::Span<AbstractTensorHandle*> outputs,
+                      const GradientRegistry& registry);
+
+// Updates the weights for a neural network, given incoming grads and a
+// learning rate.
+Status UpdateWeights(AbstractContext* ctx,
+                     std::vector<AbstractTensorHandle*>& grads,
+                     std::vector<AbstractTensorHandle*>& weights,
+                     AbstractTensorHandle* learning_rate);
+
+AbstractContext* BuildFunction(const char* fn_name);
+
+Status CreateParamsForInputs(AbstractContext* ctx,
+                             absl::Span<AbstractTensorHandle* const> inputs,
+                             std::vector<AbstractTensorHandle*>* params);
+
+using Model = std::function<Status(
+    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
+    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
+
+Status RunModel(Model model, AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
+                const GradientRegistry& registry);
+
+Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
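+
+// Example usage (a sketch; assumes `registry` has the needed gradients
+// registered and `x`, `y` are valid handles owned by the caller):
+//
+//   AbstractContext* ctx = nullptr;
+//   TF_RETURN_IF_ERROR(
+//       BuildImmediateExecutionContext(/*use_tfrt=*/false, &ctx));
+//   std::vector<AbstractTensorHandle*> outputs(2);
+//   TF_RETURN_IF_ERROR(RunModel(AddGradModel, ctx, {x, y},
+//                               absl::MakeSpan(outputs),
+//                               /*use_function=*/true, registry));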
+
+}  // namespace internal
+}  // namespace gradients
+}  // namespace tensorflow
diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
index 678d1a7..df5504a 100644
--- a/tensorflow/c/eager/parallel_device/BUILD
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -76,6 +76,7 @@
         "//tensorflow/c/eager:c_api_experimental",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
         "@com_google_absl//absl/types:variant",
     ],
 )
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc
index d0e9f35..b9d7be7 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -255,28 +255,44 @@
 // Since this function is used to satisfy the TFE_CustomDevice C API,
 // device_info is passed in using a C-style generic. It must always be a
 // ParallelDevice.
-void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
-                           TFE_TensorHandle** inputs,
-                           const char* operation_name,
-                           const TFE_OpAttrs* attributes, int* num_outputs,
+void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                            TFE_TensorHandle** outputs, TF_Status* status,
                            void* device_info) {
+  const char* requested_placement = TFE_OpGetDevice(original_op, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  if (*requested_placement == '\0') {
+    TF_SetStatus(
+        status, TF_INTERNAL,
+        "Ops must be placed on the parallel device explicitly, or their inputs "
+        "first un-packed. Got an un-placed op with an input placed on the "
+        "parallel device.");
+    return;
+  }
+  TFE_Context* context = TFE_OpGetContext(original_op, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  const char* operation_name = TFE_OpGetName(original_op, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
+
   NamedParallelDevice* named_device =
       reinterpret_cast<NamedParallelDevice*>(device_info);
   std::vector<MaybeParallelTensorUnowned> typed_inputs;
+  int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
+  if (TF_GetCode(status) != TF_OK) return;
   typed_inputs.reserve(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
+    TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
+    if (TF_GetCode(status) != TF_OK) return;
     const char* tensor_handle_device =
-        TFE_TensorHandleDeviceName(inputs[i], status);
+        TFE_TensorHandleDeviceName(input, status);
     if (TF_GetCode(status) != TF_OK) return;
     if (named_device->name() == tensor_handle_device) {
       // We assume that any tensors already placed on this device are
       // ParallelTensors.
       typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
-          TFE_TensorHandleDevicePointer(inputs[i], status)));
+          TFE_TensorHandleDevicePointer(input, status)));
       if (TF_GetCode(status) != TF_OK) return;
     } else {
-      typed_inputs.emplace_back(inputs[i]);
+      typed_inputs.emplace_back(input);
     }
   }
 
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
index 1b707fe..e270bfc 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
 
+#include "tensorflow/c/tf_status.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -261,18 +262,27 @@
                                            status);
 }
 
-std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
-    TFE_Context* context, TF_Status* status) const {
+std::unique_ptr<ParallelTensor> ParallelDevice::Vector(
+    TFE_Context* context, TF_Status* status,
+    absl::Span<const int32_t> values) const {
   // TODO(allenl): We could cache DeviceIDs (keyed by context).
   std::vector<TensorHandlePtr> components;
   components.reserve(underlying_devices_.size());
-  for (int device_index = 0; device_index < underlying_devices_.size();
+
+  if (values.size() != num_underlying_devices()) {
+    TF_SetStatus(
+        status, TF_INVALID_ARGUMENT,
+        "Number of values did not match number of underlying devices.");
+    return nullptr;
+  }
+
+  for (int device_index = 0; device_index < num_underlying_devices();
        ++device_index) {
-    int32_t* device_id = new int32_t;
-    *device_id = device_index;
+    int32_t* device_value = new int32_t;
+    *device_value = values[device_index];
     std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
         TF_NewTensor(
-            TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
+            TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value,
             sizeof(int32_t),
             [](void* data, size_t, void* arg) {
               delete reinterpret_cast<int32_t*>(data);
@@ -301,6 +311,16 @@
                                            status);
 }
 
+std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
+    TFE_Context* context, TF_Status* status) const {
+  std::vector<int32_t> ids;
+  ids.reserve(num_underlying_devices());
+  for (int i = 0; i < num_underlying_devices(); ++i) {
+    ids.push_back(i);
+  }
+  return Vector(context, status, ids);
+}
+
 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
 ParallelDevice::Execute(TFE_Context* context,
                         const std::vector<ParallelTensor*>& inputs,
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h
index cbfea31..b3dc47a 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "absl/types/optional.h"
+#include "absl/types/span.h"
 #include "absl/types/variant.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
@@ -61,6 +62,11 @@
                                                        TFE_TensorHandle* tensor,
                                                        TF_Status* status) const;
 
+  // Construct a parallel tensor consisting of the scalar values from `values`.
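+  // For example, on a parallel device with two underlying devices,
+  // Vector(context, status, {5, 6}) yields a parallel tensor whose
+  // per-device components are the scalars 5 and 6.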
+  std::unique_ptr<ParallelTensor> Vector(
+      TFE_Context* context, TF_Status* status,
+      absl::Span<const int32_t> values) const;
+
   // A parallel tensor with scalar integers numbering component devices.
   std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                             TF_Status* status) const;
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
index 68875d6..3965359 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
@@ -29,6 +29,7 @@
         ":gcs_helper",
         ":ram_file_block_cache",
         "//tensorflow/c:env",
+        "//tensorflow/c:logging",
         "//tensorflow/c:tf_status",
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
         "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
index e01af91..5285989 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
@@ -23,6 +23,7 @@
 #include "google/cloud/storage/client.h"
 #include "tensorflow/c/env.h"
 #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
+#include "tensorflow/c/logging.h"
 #include "tensorflow/c/tf_status.h"
 
 // Implementation of a filesystem for GCS environments.
@@ -134,6 +135,8 @@
   }
   // `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here.
   TF_SetStatus(status, TF_OK, "");
+  TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset,
+          read);
   stream.read(buffer, read);
   read = stream.gcount();
   if (read < buffer_size) {
@@ -146,6 +149,8 @@
                                   path, " @ ", offset)
                          .c_str());
       }
+      TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(),
+              offset);
     }
   }
   return read;
@@ -284,6 +289,8 @@
       TF_SetStatusFromGCSStatus(metadata.status(), status);
       return;
     }
+    TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(),
+            temporary_object.c_str(), bucket.c_str(), object.c_str());
     const std::vector<gcs::ComposeSourceObject> source_objects = {
         {object, {}, {}}, {temporary_object, {}, {}}};
     metadata = gcs_client->ComposeObject(bucket, source_objects, object);
@@ -321,6 +328,8 @@
                  "The internal temporary file is not writable.");
     return;
   }
+  TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(),
+          gcs_file->object.c_str(), n);
   gcs_file->sync_need = true;
   gcs_file->outfile.write(buffer, n);
   if (!gcs_file->outfile)
@@ -346,6 +355,8 @@
 void Flush(const TF_WritableFile* file, TF_Status* status) {
   auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
   if (gcs_file->sync_need) {
+    TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(),
+            gcs_file->object.c_str());
     if (!gcs_file->outfile) {
       TF_SetStatus(status, TF_INTERNAL,
                    "Could not append to the internal temporary file.");
@@ -353,6 +364,8 @@
     }
     SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset,
              &gcs_file->outfile, gcs_file->gcs_client, status);
+    TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(),
+            gcs_file->object.c_str());
     if (TF_GetCode(status) != TF_OK) return;
     gcs_file->sync_need = false;
   } else {
@@ -361,11 +374,16 @@
 }
 
 void Sync(const TF_WritableFile* file, TF_Status* status) {
+  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
+  TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(),
+          gcs_file->object.c_str());
   Flush(file, status);
 }
 
 void Close(const TF_WritableFile* file, TF_Status* status) {
   auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
+  TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(),
+          gcs_file->object.c_str());
   if (gcs_file->sync_need) {
     Flush(file, status);
   }
@@ -428,6 +446,8 @@
   if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
     max_staleness = value;
   }
+  TF_VLog(1, "GCS cache max size = %u; block size = %u; max staleness = %u",
+          max_bytes, block_size, max_staleness);
 
   file_block_cache = std::make_unique<RamFileBlockCache>(
       block_size, max_bytes, max_staleness,
@@ -511,6 +531,10 @@
   stat->base.mtime_nsec =
       metadata->time_storage_class_updated().time_since_epoch().count();
   stat->base.is_directory = object.back() == '/';
+  TF_VLog(1,
+          "Stat of: gs://%s/%s -- length: %u; generation: %u; mtime_nsec: %u",
+          bucket.c_str(), object.c_str(), stat->base.length,
+          stat->generation_number, stat->base.mtime_nsec);
   return TF_SetStatus(status, TF_OK, "");
 }
 
@@ -545,9 +569,10 @@
       if (TF_GetCode(status) != TF_OK) return -1;
       if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
               path, stat.generation_number)) {
-        std::cout
-            << "File signature has been changed. Refreshing the cache. Path: "
-            << path;
+        TF_VLog(
+            1,
+            "File signature has been changed. Refreshing the cache. Path: %s",
+            path.c_str());
       }
       read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
     } else {
@@ -579,6 +604,7 @@
        (gcs_file->compose ? 0 : -1)});
   // We are responsible for freeing the pointer returned by TF_GetTempFileName
   free(temp_file_name);
+  TF_VLog(3, "GcsWritableFile: %s", path);
   TF_SetStatus(status, TF_OK, "");
 }
 
@@ -624,7 +650,8 @@
       return;
     }
   }
-
+  TF_VLog(3, "GcsWritableFile: %s with existing file %s", path,
+          temp_file_name.c_str());
   TF_SetStatus(status, TF_OK, "");
 }
 
@@ -812,6 +839,10 @@
                TF_Status* status) {
   std::string dir = path;
   MaybeAppendSlash(&dir);
+  TF_VLog(3,
+          "CreateDir: creating directory with path: %s and "
+          "path_with_slash: %s",
+          path, dir.c_str());
   std::string bucket, object;
   ParseGCSPath(dir, true, &bucket, &object, status);
   if (TF_GetCode(status) != TF_OK) return;
@@ -826,8 +857,11 @@
   }
 
   PathExists(filesystem, dir.c_str(), status);
-  if (TF_GetCode(status) == TF_OK)
+  if (TF_GetCode(status) == TF_OK) {
+    // Use the original name for a correct error here.
+    TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path);
     return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
+  }
 
   auto metadata = gcs_file->gcs_client.InsertObject(
       bucket, object, "",
@@ -933,6 +967,7 @@
 static void RenameObject(const TF_Filesystem* filesystem,
                          const std::string& src, const std::string& dst,
                          TF_Status* status) {
+  TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str());
   std::string bucket_src, object_src;
   ParseGCSPath(src, false, &bucket_src, &object_src, status);
   if (TF_GetCode(status) != TF_OK) return;
@@ -946,6 +981,7 @@
       bucket_src, object_src, bucket_dst, object_dst);
   TF_SetStatusFromGCSStatus(metadata.status(), status);
   if (TF_GetCode(status) != TF_OK) return;
+  TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str());
 
   ClearFileCaches(gcs_file, dst);
   DeleteFile(filesystem, src.c_str(), status);
diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD
index 51ffd70..281acd2 100644
--- a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD
@@ -1,5 +1,5 @@
 # Experimental hadoop filesystem plugin.
-load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
+load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
 
 package(
     licenses = ["notice"],  # Apache 2.0
@@ -33,3 +33,20 @@
         "@com_google_absl//absl/synchronization",
     ],
 )
+
+tf_cc_test(
+    name = "hadoop_filesystem_test",
+    srcs = [
+        "hadoop_filesystem_test.cc",
+    ],
+    tags = [
+        "manual",
+        "notap",
+    ],
+    deps = [
+        ":hadoop_filesystem_impl",
+        "//tensorflow/core/platform:path",
+        "//tensorflow/core/platform:stacktrace_handler",
+        "//tensorflow/core/platform:test",
+    ],
+)
diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc
index e53e3d0..093bc4d 100644
--- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc
@@ -308,7 +308,7 @@
         handle(handle) {}
 } HDFSFile;
 
-static void Cleanup(TF_WritableFile* file) {
+void Cleanup(TF_WritableFile* file) {
   auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
   hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle);
   hdfs_file->fs = nullptr;
diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h
index 850cefe..0df45c5 100644
--- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h
+++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h
@@ -15,7 +15,58 @@
 #ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
 #define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
 
+#include <string>
+
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/c/tf_status.h"
 
+void ParseHadoopPath(const std::string& fname, std::string* scheme,
+                     std::string* namenode, std::string* path);
+class LibHDFS;
+
+namespace tf_random_access_file {
+void Cleanup(TF_RandomAccessFile* file);
+int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
+             char* buffer, TF_Status* status);
+}  // namespace tf_random_access_file
+
+namespace tf_writable_file {
+void Cleanup(TF_WritableFile* file);
+void Append(const TF_WritableFile* file, const char* buffer, size_t n,
+            TF_Status* status);
+int64_t Tell(const TF_WritableFile* file, TF_Status* status);
+void Sync(const TF_WritableFile* file, TF_Status* status);
+void Flush(const TF_WritableFile* file, TF_Status* status);
+void Close(const TF_WritableFile* file, TF_Status* status);
+}  // namespace tf_writable_file
+
+namespace tf_hadoop_filesystem {
+void Init(TF_Filesystem* filesystem, TF_Status* status);
+void Cleanup(TF_Filesystem* filesystem);
+void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
+                         TF_RandomAccessFile* file, TF_Status* status);
+void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
+                     TF_WritableFile* file, TF_Status* status);
+void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
+                                     const char* path,
+                                     TF_ReadOnlyMemoryRegion* region,
+                                     TF_Status* status);
+void PathExists(const TF_Filesystem* filesystem, const char* path,
+                TF_Status* status);
+void Stat(const TF_Filesystem* filesystem, const char* path,
+          TF_FileStatistics* stats, TF_Status* status);
+int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
+                    TF_Status* status);
+void DeleteFile(const TF_Filesystem* filesystem, const char* path,
+                TF_Status* status);
+void CreateDir(const TF_Filesystem* filesystem, const char* path,
+               TF_Status* status);
+void DeleteDir(const TF_Filesystem* filesystem, const char* path,
+               TF_Status* status);
+void RenameFile(const TF_Filesystem* filesystem, const char* src,
+                const char* dst, TF_Status* status);
+int GetChildren(const TF_Filesystem* filesystem, const char* path,
+                char*** entries, TF_Status* status);
+}  // namespace tf_hadoop_filesystem
+
 #endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc
new file mode 100644
index 0000000..57888c1
--- /dev/null
+++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h"
+
+#include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/stacktrace_handler.h"
+#include "tensorflow/core/platform/test.h"
+
+#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
+#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
+
+namespace tensorflow {
+namespace {
+
+class HadoopFileSystemTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    status_ = TF_NewStatus();
+    filesystem_ = new TF_Filesystem;
+    tf_hadoop_filesystem::Init(filesystem_, status_);
+    ASSERT_TF_OK(status_) << "Could not initialize filesystem.";
+  }
+  void TearDown() override {
+    TF_DeleteStatus(status_);
+    tf_hadoop_filesystem::Cleanup(filesystem_);
+    delete filesystem_;
+  }
+
+  std::string TmpDir(const std::string& path) {
+    char* test_dir = getenv("HADOOP_TEST_TMPDIR");
+    if (test_dir != nullptr) {
+      return io::JoinPath(std::string(test_dir), path);
+    } else {
+      return "file://" + io::JoinPath(testing::TmpDir(), path);
+    }
+  }
+
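+  // Wraps a plugin file struct in a unique_ptr whose deleter runs the
+  // plugin's Cleanup() only if a file was actually opened
+  // (plugin_file != nullptr).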
+  std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)>
+  GetWriter() {
+    std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> writer(
+        new TF_WritableFile, [](TF_WritableFile* file) {
+          if (file != nullptr) {
+            if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
+            delete file;
+          }
+        });
+    writer->plugin_file = nullptr;
+    return writer;
+  }
+
+  std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile* file)>
+  GetReader() {
+    std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile * file)>
+        reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
+          if (file != nullptr) {
+            if (file->plugin_file != nullptr)
+              tf_random_access_file::Cleanup(file);
+            delete file;
+          }
+        });
+    reader->plugin_file = nullptr;
+    return reader;
+  }
+
+ protected:
+  TF_Filesystem* filesystem_;
+  TF_Status* status_;
+};
+
+TEST_F(HadoopFileSystemTest, Init) { ASSERT_TF_OK(status_); }
+
+}  // namespace
+}  // namespace tensorflow
+
+GTEST_API_ int main(int argc, char** argv) {
+  tensorflow::testing::InstallStacktraceHandler();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD
index 9e7dc30..36a3251 100644
--- a/tensorflow/c/experimental/gradients/BUILD
+++ b/tensorflow/c/experimental/gradients/BUILD
@@ -37,6 +37,28 @@
         "//tensorflow/c/eager:gradients",
         "//tensorflow/c/experimental/ops:array_ops",
         "//tensorflow/c/experimental/ops:math_ops",
+        "//tensorflow/c/experimental/ops:nn_ops",
+        "//tensorflow/core/lib/llvm_rtti",
+    ],
+)
+
+cc_library(
+    name = "nn_grad",
+    srcs = ["nn_grad.cc"],
+    hdrs = [
+        "nn_grad.h",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        "//tensorflow/c/eager:abstract_operation",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:c_api_unified_internal",
+        "//tensorflow/c/eager:gradients",
+        "//tensorflow/c/experimental/ops:array_ops",
+        "//tensorflow/c/experimental/ops:math_ops",
+        "//tensorflow/c/experimental/ops:nn_ops",
         "//tensorflow/core/lib/llvm_rtti",
     ],
 )
diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc
index cfe122b..3537b30 100644
--- a/tensorflow/c/experimental/gradients/math_grad.cc
+++ b/tensorflow/c/experimental/gradients/math_grad.cc
@@ -18,11 +18,14 @@
 #include "tensorflow/c/eager/gradients.h"
 #include "tensorflow/c/experimental/ops/array_ops.h"
 #include "tensorflow/c/experimental/ops/math_ops.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
 
 using std::vector;
 using tensorflow::ops::Conj;
 using tensorflow::ops::Identity;
+using tensorflow::ops::MatMul;
 using tensorflow::ops::Mul;
 
 namespace tensorflow {
 namespace gradients {
@@ -36,13 +39,17 @@
     vector<AbstractTensorHandle*> identity_outputs(1);
     // TODO(b/145674566): Handle name unification in tracing code.
     // TODO(b/161805092): Support broadcasting.
+
+    std::string name = "Identity_A";
     TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
                                      absl::MakeSpan(identity_outputs),
-                                     "Identity0"));
+                                     name.c_str()));
     (*grad_outputs)[0] = identity_outputs[0];
+
+    name = "Identity_B";
     TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
                                      absl::MakeSpan(identity_outputs),
-                                     "Identity1"));
+                                     name.c_str()));
     (*grad_outputs)[1] = identity_outputs[0];
     return Status::OK();
   }
@@ -57,12 +64,15 @@
   Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
                  vector<AbstractTensorHandle*>* grad_outputs) override {
     vector<AbstractTensorHandle*> conj_outputs(1);
-    TF_RETURN_IF_ERROR(
-        Conj(ctx->ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "ExpConj"));
+    std::string name = "Conj_Exp_Grad";
+    TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()},
+                            absl::MakeSpan(conj_outputs), name.c_str()));
     AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
     grad_outputs->resize(1);
+
+    name = "Mul_Exp_Grad";
     TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
-                           absl::MakeSpan(*grad_outputs), "ExpGradMul"));
+                           absl::MakeSpan(*grad_outputs), name.c_str()));
     return Status::OK();
   }
   ~ExpGradientFunction() override {}
@@ -71,6 +81,115 @@
   AbstractTensorHandlePtr exp_;
 };
 
+class MatMulGradientFunction : public GradientFunction {
+ public:
+  explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
+                                  AttrBuilder f_attrs)
+      : forward_inputs(f_inputs), forward_attrs(f_attrs) {}
+
+  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
+                 vector<AbstractTensorHandle*>* grad_outputs) override {
+    /* Given upstream grad U and a matmul op A*B, the gradients are:
+     *
+     *    dA = U * B.T
+     *    dB = A.T * U
+     *
+     *    where A.T means `transpose(A)`
+     */
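+    // The transposed variants below follow from the same chain rule; e.g. for
+    // C = A.T * B.T, the last branch computes dA = B.T * U.T and
+    // dB = U.T * A.T.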
+    AbstractTensorHandle* upstream_grad = grad_inputs[0];
+    grad_outputs->resize(2);
+
+    // Get transpose attrs
+    bool t_a;
+    TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_a", &t_a));
+
+    bool t_b;
+    TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_b", &t_b));
+
+    // Conj each input
+    vector<AbstractTensorHandle*> conj_outputs(1);
+    std::string name = "Conj_A_MatMul_Grad";
+    TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]},
+                            absl::MakeSpan(conj_outputs), name.c_str()));
+
+    AbstractTensorHandle* A = conj_outputs[0];
+
+    name = "Conj_B_MatMul_Grad";
+    TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]},
+                            absl::MakeSpan(conj_outputs), name.c_str()));
+
+    AbstractTensorHandle* B = conj_outputs[0];
+
+    // Calculate the gradients.
+    vector<AbstractTensorHandle*> matmul_A_outputs(1);
+    vector<AbstractTensorHandle*> matmul_B_outputs(1);
+    std::string name_grad_A = "MatMul_Grad_A";
+    std::string name_grad_B = "MatMul_Grad_B";
+    if (!t_a && !t_b) {
+      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
+                                absl::MakeSpan(matmul_A_outputs),
+                                name_grad_A.c_str(),
+                                /*transpose_a=*/false,
+                                /*transpose_b=*/true));
+
+      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
+                                absl::MakeSpan(matmul_B_outputs),
+                                name_grad_B.c_str(),
+                                /*transpose_a=*/true,
+                                /*transpose_b=*/false));
+    } else if (!t_a && t_b) {
+      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
+                                absl::MakeSpan(matmul_A_outputs),
+                                name_grad_A.c_str(),
+                                /*transpose_a=*/false,
+                                /*transpose_b=*/false));
+
+      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
+                                absl::MakeSpan(matmul_B_outputs),
+                                name_grad_B.c_str(),
+                                /*transpose_a=*/true,
+                                /*transpose_b=*/false));
+    } else if (t_a && !t_b) {
+      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
+                                absl::MakeSpan(matmul_A_outputs),
+                                name_grad_A.c_str(),
+                                /*transpose_a=*/false,
+                                /*transpose_b=*/true));
+
+      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
+                                absl::MakeSpan(matmul_B_outputs),
+                                name_grad_B.c_str(),
+                                /*transpose_a=*/false,
+                                /*transpose_b=*/false));
+    } else {  // t_a && t_b
+      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
+                                absl::MakeSpan(matmul_A_outputs),
+                                name_grad_A.c_str(),
+                                /*transpose_a=*/true,
+                                /*transpose_b=*/true));
+
+      TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
+                                absl::MakeSpan(matmul_B_outputs),
+                                name_grad_B.c_str(),
+                                /*transpose_a=*/true,
+                                /*transpose_b=*/true));
+    }
+
+    // Gradient for A
+    (*grad_outputs)[0] = matmul_A_outputs[0];
+
+    // Gradient for B
+    (*grad_outputs)[1] = matmul_B_outputs[0];
+    return Status::OK();
+  }
+  ~MatMulGradientFunction() override {}
+
+ private:
+  vector<AbstractTensorHandle*> forward_inputs;
+  AttrBuilder forward_attrs;
+};
+
 }  // namespace
 
 BackwardFunction* AddRegisterer(const ForwardOperation& op) {
@@ -91,5 +210,14 @@
   return new BackwardFunction(gradient_function, default_gradients);
 }
 
+BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
+  auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs);
+  // For ops with a single output, the gradient function is not called if there
+  // is no incoming gradient. So we do not need to worry about creating zeros
+  // grads in this case.
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
 }  // namespace gradients
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h
index 7348ef3..205419e 100644
--- a/tensorflow/c/experimental/gradients/math_grad.h
+++ b/tensorflow/c/experimental/gradients/math_grad.h
@@ -21,7 +21,8 @@
 namespace gradients {
 BackwardFunction* AddRegisterer(const ForwardOperation& op);
 BackwardFunction* ExpRegisterer(const ForwardOperation& op);
+BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
 }  // namespace gradients
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
+#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
\ No newline at end of file
diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc
new file mode 100644
index 0000000..3da1e0d
--- /dev/null
+++ b/tensorflow/c/experimental/gradients/nn_grad.cc
@@ -0,0 +1,111 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/experimental/gradients/nn_grad.h"
+
+#include "tensorflow/c/experimental/ops/array_ops.h"
+#include "tensorflow/c/experimental/ops/math_ops.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+
+using std::vector;
+using tensorflow::ops::ReluGrad;
+
+namespace tensorflow {
+namespace gradients {
+namespace {
+
+class ReluGradientFunction : public GradientFunction {
+ public:
+  explicit ReluGradientFunction(vector<AbstractTensorHandle*> f_outputs)
+      : forward_outputs(f_outputs) {}
+
+  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
+                 vector<AbstractTensorHandle*>* grad_outputs) override {
+    AbstractTensorHandle* upstream_grad = grad_inputs[0];
+    AbstractTensorHandle* activations = forward_outputs[0];
+    grad_outputs->resize(1);
+    vector<AbstractTensorHandle*> relugrad_outputs(1);
+
+    // Calculate Grad
+    std::string name = "relu_grad";
+
+    TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations},
+                                absl::MakeSpan(relugrad_outputs),
+                                name.c_str()));
+    (*grad_outputs)[0] = relugrad_outputs[0];
+
+    return Status::OK();
+  }
+  ~ReluGradientFunction() override {}
+
+ private:
+  vector<AbstractTensorHandle*> forward_outputs;
+};
+
+class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
+ public:
+  explicit SparseSoftmaxCrossEntropyLossGradientFunction(
+      vector<AbstractTensorHandle*> f_outputs)
+      : forward_outputs(f_outputs) {}
+
+  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
+                 vector<AbstractTensorHandle*>* grad_outputs) override {
+    grad_outputs->resize(2);
+
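+    // The forward op returns [loss, softmax_grad]; the gradient w.r.t. the
+    // scores is the upstream gradient scaled by that precomputed local grad.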
+    // Grad for Softmax Input
+    std::string name = "Mul_Softmax_Grad";
+    vector<AbstractTensorHandle*> mul_outputs(1);
+    TF_RETURN_IF_ERROR(
+        ops::Mul(ctx->ctx, {grad_inputs[0], forward_outputs[1]},
+                 absl::MakeSpan(mul_outputs),
+                 name.c_str()));  // upstream_grad * local softmax grad
+    (*grad_outputs)[0] = mul_outputs[0];
+
+    // Grad for labels is null
+    (*grad_outputs)[1] = nullptr;
+
+    return Status::OK();
+  }
+  ~SparseSoftmaxCrossEntropyLossGradientFunction() override {}
+
+ private:
+  vector<AbstractTensorHandle*> forward_outputs;
+};
+
+}  // namespace
+
+BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
+  auto gradient_function = new ReluGradientFunction(op.outputs);
+  // For ops with a single output, the gradient function is not called if there
+  // is no incoming gradient. So we do not need to worry about creating zeros
+  // grads in this case.
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
+BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
+    const ForwardOperation& op) {
+  auto gradient_function =
+      new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs);
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
+}  // namespace gradients
+}  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h
new file mode 100644
index 0000000..d002725
--- /dev/null
+++ b/tensorflow/c/experimental/gradients/nn_grad.h
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
+#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
+
+#include "tensorflow/c/eager/gradients.h"
+
+namespace tensorflow {
+namespace gradients {
+BackwardFunction* ReluRegisterer(const ForwardOperation& op);
+BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
+    const ForwardOperation& op);
+}  // namespace gradients
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
\ No newline at end of file
diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD
index d13d7a7..c5810bf 100644
--- a/tensorflow/c/experimental/ops/BUILD
+++ b/tensorflow/c/experimental/ops/BUILD
@@ -38,10 +38,29 @@
     deps = [
         ":array_ops",
         "//tensorflow/c/eager:abstract_context",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:c_api_unified_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/lib/llvm_rtti",
+        "//tensorflow/core/platform:errors",
+    ],
+)
+
+cc_library(
+    name = "nn_ops",
+    srcs = [
+        "nn_ops.cc",
+    ],
+    hdrs = [
+        "nn_ops.h",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
         "//tensorflow/c/eager:abstract_operation",
         "//tensorflow/c/eager:abstract_tensor_handle",
         "//tensorflow/c/eager:c_api_unified_internal",
-        "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core/lib/llvm_rtti",
         "//tensorflow/core/platform:errors",
     ],
diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc
index ab2d114..df0f463 100644
--- a/tensorflow/c/experimental/ops/array_ops.cc
+++ b/tensorflow/c/experimental/ops/array_ops.cc
@@ -19,7 +19,7 @@
 
 namespace tensorflow {
 namespace ops {
-// Creates an Identity op.
+
 Status Identity(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
@@ -35,5 +35,19 @@
   return identity_op->Execute(outputs, &num_retvals);
 }
 
+Status ZerosLike(AbstractContext* ctx,
+                 absl::Span<AbstractTensorHandle* const> inputs,
+                 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr z_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
+  if (isa<tensorflow::tracing::TracingOperation>(z_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<tracing::TracingOperation>(z_op.get())->SetOpName(name));
+  }
+  TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0]));
+  int num_retvals = 1;
+  return z_op->Execute(outputs, &num_retvals);
+}
+
 }  // namespace ops
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h
index 226461f..8dc68db 100644
--- a/tensorflow/c/experimental/ops/array_ops.h
+++ b/tensorflow/c/experimental/ops/array_ops.h
@@ -22,9 +22,15 @@
 
 namespace tensorflow {
 namespace ops {
+
 Status Identity(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
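+// Returns a tensor of zeros with the same shape and dtype as `inputs[0]`.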
+Status ZerosLike(AbstractContext* ctx,
+                 absl::Span<AbstractTensorHandle* const> inputs,
+                 absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
 }  // namespace ops
 }  // namespace tensorflow
 
diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc
index e91acbd..82c2f0e 100644
--- a/tensorflow/c/experimental/ops/math_ops.cc
+++ b/tensorflow/c/experimental/ops/math_ops.cc
@@ -51,5 +51,60 @@
   return Status::OK();
 }
 
+Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
+           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr add_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));
+
+  if (isa<tracing::TracingOperation>(add_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
+  }
+
+  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
+  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));
+
+  int num_retvals = 1;
+  TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals));
+  return Status::OK();
+}
+
+Status MatMul(AbstractContext* ctx,
+              absl::Span<AbstractTensorHandle* const> inputs,
+              absl::Span<AbstractTensorHandle*> outputs, const char* name,
+              bool transpose_a, bool transpose_b) {
+  AbstractOperationPtr matmul_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));
+
+  if (isa<tracing::TracingOperation>(matmul_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
+  }
+
+  TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
+  TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));
+
+  TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a));
+  TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b));
+
+  int num_retvals = 1;
+  TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals));
+  return Status::OK();
+}
+
+Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
+           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr neg_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
+  if (isa<TracingOperation>(neg_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<TracingOperation>(neg_op.get())->SetOpName(name));
+  }
+  TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));
+
+  int num_retvals = 1;
+  return neg_op->Execute(outputs, &num_retvals);
+}
+
 }  // namespace ops
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h
index 4d7c3d8..ed1e6c5 100644
--- a/tensorflow/c/experimental/ops/math_ops.h
+++ b/tensorflow/c/experimental/ops/math_ops.h
@@ -25,6 +25,15 @@
 Status Conj(AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs, const char* name);
+Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
+           absl::Span<AbstractTensorHandle*> outputs, const char* name);
+Status MatMul(AbstractContext* ctx,
+              absl::Span<AbstractTensorHandle* const> inputs,
+              absl::Span<AbstractTensorHandle*> outputs, const char* name,
+              bool transpose_a = false, bool transpose_b = false);
+Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
+           absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
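+// Example (a sketch; assumes `ctx` is a valid context and `a`, `b` are
+// compatible matrix handles):
+//   std::vector<AbstractTensorHandle*> out(1);
+//   TF_RETURN_IF_ERROR(ops::MatMul(ctx, {a, b}, absl::MakeSpan(out),
+//                                  "matmul_example", /*transpose_a=*/false,
+//                                  /*transpose_b=*/true));
+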
 }  // namespace ops
 }  // namespace tensorflow
 
diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc
new file mode 100644
index 0000000..8f5f550
--- /dev/null
+++ b/tensorflow/c/experimental/ops/nn_ops.cc
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+namespace ops {
+
+// Computes sparse softmax cross-entropy loss given scores and labels. The
+// second output is the backprop gradient, consumed by the softmax loss
+// gradient function.
+Status SparseSoftmaxCrossEntropyLoss(
+    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
+    absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits",
+                                       /*raw_device_name=*/nullptr));
+
+  if (isa<tracing::TracingOperation>(sm_loss_op.get())) {
+    TF_RETURN_IF_ERROR(
+        dyn_cast<tracing::TracingOperation>(sm_loss_op.get())->SetOpName(name));
+  }
+
+  TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0]));  // input scores
+  TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1]));  // labels
+
+  // Outputs will contain: [loss_vals, gradients].
+  int num_retvals = 2;
+  return sm_loss_op->Execute(outputs, &num_retvals);
+}
+
+// Computes the Relu gradient given upstream gradients and the original inputs.
+Status ReluGrad(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr relugrad_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(
+      relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr));
+
+  if (isa<tracing::TracingOperation>(relugrad_op.get())) {
+    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(relugrad_op.get())
+                           ->SetOpName(name));
+  }
+
+  TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0]));  // upstream grads
+  TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1]));  // relu inputs
+
+  int num_retvals = 1;
+  return relugrad_op->Execute(outputs, &num_retvals);
+}
+
+}  // namespace ops
+}  // namespace tensorflow
diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h
new file mode 100644
index 0000000..3e618b0
--- /dev/null
+++ b/tensorflow/c/experimental/ops/nn_ops.h
@@ -0,0 +1,37 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
+#define TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
+
+#include "tensorflow/c/eager/abstract_operation.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+
+namespace tensorflow {
+namespace ops {
+
+Status SparseSoftmaxCrossEntropyLoss(
+    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
+    absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
+Status ReluGrad(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
+}  // namespace ops
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_
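
A hedged usage sketch for the loss op above: SparseSoftmaxCrossEntropyLoss executes SparseSoftmaxCrossEntropyWithLogits, which produces two retvals. The helper name `RunSoftmaxLoss` is hypothetical:

// Hypothetical helper: outputs[0] holds the per-example losses, outputs[1]
// the gradient with respect to the scores (as noted in nn_ops.cc above).
Status RunSoftmaxLoss(AbstractContext* ctx, AbstractTensorHandle* scores,
                      AbstractTensorHandle* labels,
                      AbstractTensorHandle** loss,
                      AbstractTensorHandle** backprop) {
  std::vector<AbstractTensorHandle*> outputs(2);
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyLoss(
      ctx, {scores, labels}, absl::MakeSpan(outputs), "sm_loss"));
  *loss = outputs[0];
  *backprop = outputs[1];
  return Status::OK();
}
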
diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD
index 3e0989b..2feb7c1 100644
--- a/tensorflow/c/experimental/saved_model/core/BUILD
+++ b/tensorflow/c/experimental/saved_model/core/BUILD
@@ -229,13 +229,13 @@
         "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
         "//tensorflow/core:all_kernels",
         "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/common_runtime:core_cpu_lib",
         "//tensorflow/core/common_runtime/eager:context",
         "//tensorflow/core/common_runtime/eager:core",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
     ],
 )
 
diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h
index 934fa6d..48a20ef 100644
--- a/tensorflow/c/experimental/saved_model/core/concrete_function.h
+++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h
@@ -43,8 +43,8 @@
   virtual ~ConcreteFunction() = default;
 
   // This method creates a "Call" Op used to execute the function.
-  virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
-                           ImmediateOpPtr* out) = 0;
+  virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
+                            ImmediateOpPtr* out) const = 0;
 
   virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
 };
diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc
index 492a58f..be9ffff 100644
--- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc
+++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc
@@ -37,10 +37,11 @@
 
 Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
                                            DataType dtype, TensorShape shape,
+                                           const char* raw_device_name,
                                            ImmediateTensorHandlePtr* handle) {
   ImmediateOpPtr varhandle_op(ctx->CreateOperation());
 
-  TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
+  TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", raw_device_name));
   TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
 
   // Note that if shape is unknown rank, shape.dim_sizes() will be empty, and
diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h
index 13c941a..accad15 100644
--- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h
+++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h
@@ -31,6 +31,7 @@
 // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
 Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
                                            DataType dtype, TensorShape shape,
+                                           const char* raw_device_name,
                                            ImmediateTensorHandlePtr* handle);
 
 // Executes an AssignVariableOp using `ctx`, assigning the variable associated
diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
index 55a4a32..5ce027f 100644
--- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
+++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc
@@ -55,7 +55,7 @@
   // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
   ImmediateTensorHandlePtr handle;
   TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
-      context(), DT_FLOAT, {}, &handle));
+      context(), DT_FLOAT, {}, nullptr, &handle));
   // The created TensorHandle should be a DT_Resource
   EXPECT_EQ(handle->DataType(), DT_RESOURCE);
 }
@@ -65,7 +65,7 @@
   // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
   ImmediateTensorHandlePtr handle;
   TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
-      context(), DT_FLOAT, {}, &handle));
+      context(), DT_FLOAT, {}, nullptr, &handle));
 
   // Destroy the variable
   TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
@@ -76,7 +76,7 @@
   // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
   ImmediateTensorHandlePtr variable;
   TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
-      context(), DT_FLOAT, {}, &variable));
+      context(), DT_FLOAT, {}, nullptr, &variable));
 
   // Create a Scalar float TensorHandle with value 42, and assign it to
   // the variable.
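
For context on the new `raw_device_name` parameter exercised above, a minimal sketch of pinning the variable handle to an explicit device; the device string and helper name are illustrative only:

// Hypothetical helper: a non-null device string is forwarded to VarHandleOp's
// Reset call; passing nullptr (as the tests above do) keeps the previous
// behavior of letting the runtime choose a device.
Status CreateVariableOnCpu1(ImmediateExecutionContext* ctx,
                            ImmediateTensorHandlePtr* handle) {
  return internal::CreateUninitializedResourceVariable(
      ctx, DT_FLOAT, {}, "/job:localhost/replica:0/task:0/device:CPU:1",
      handle);
}
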
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
index 2b88361..25cac39 100644
--- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD
@@ -29,6 +29,26 @@
 )
 
 cc_library(
+    name = "flat_tensor_function",
+    srcs = [
+        "flat_tensor_function.cc",
+    ],
+    hdrs = [
+        "flat_tensor_function.h",
+    ],
+    deps = [
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/common_runtime/eager:context",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
     name = "variable",
     srcs = [
         "variable.cc",
@@ -68,7 +88,7 @@
         "tf_concrete_function.h",
     ],
     deps = [
-        ":tensorhandle_convertible",
+        ":flat_tensor_function",
         "//tensorflow/c/eager:abstract_tensor_handle",
         "//tensorflow/c/eager:immediate_execution_context",
         "//tensorflow/c/eager:immediate_execution_operation",
@@ -81,3 +101,26 @@
         "@com_google_absl//absl/types:span",
     ],
 )
+
+cc_library(
+    name = "tf_signature_def_function",
+    srcs = [
+        "tf_signature_def_function.cc",
+    ],
+    hdrs = [
+        "tf_signature_def_function.h",
+    ],
+    deps = [
+        ":flat_tensor_function",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/c/experimental/saved_model/core:signature_def_function",
+        "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/common_runtime/eager:context",
+        "@com_google_absl//absl/types:span",
+    ],
+)
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc
new file mode 100644
index 0000000..ad9f896
--- /dev/null
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.cc
@@ -0,0 +1,85 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/immediate_execution_operation.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace tensorflow {
+
+FlatTensorFunction::FlatTensorFunction(
+    const std::string& name,
+    std::vector<ImmediateExecutionTensorHandle*> captures,
+    ImmediateExecutionContext* ctx)
+    : name_(name), captures_(std::move(captures)), ctx_(ctx) {}
+
+FlatTensorFunction::~FlatTensorFunction() {
+  Status status = ctx_->RemoveFunction(name_);
+  if (!status.ok()) {
+    LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
+               << status.error_message();
+  }
+}
+
+Status FlatTensorFunction::Create(
+    const FunctionDef* function_def,
+    std::vector<ImmediateExecutionTensorHandle*> captures,
+    ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
+  TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
+  out->reset(new FlatTensorFunction(function_def->signature().name(),
+                                    std::move(captures), ctx));
+  return Status();
+}
+
+Status FlatTensorFunction::MakeCallOp(
+    absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
+  out->reset(ctx_->CreateOperation());
+  // In eager mode, TF2 python executes functions by constructing an op with
+  // the name of the functiondef:
+  // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
+  // In graph mode, we create a PartitionedCallOp instead:
+  // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
+
+  // TODO(bmzhao): After discussing with Allen, we should execute this via a
+  // PartitionedCallOp for compatibility with "tooling that assumes functions in
+  // graphs are PartitionedCallOps".
+  TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
+
+  // Adding the user-provided inputs to the function.
+  TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
+
+  absl::Span<AbstractTensorHandle* const> captures(
+      reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
+      captures_.size());
+
+  // Adding the captures of the function.
+  TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
+  return Status();
+}
+
+}  // namespace tensorflow
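
To make the input layout above concrete: the call op sees the user-provided inputs first, then the captures. A hedged sketch of driving MakeCallOp for a two-argument function with captures, assuming namespace tensorflow; the helper name is hypothetical:

// Hypothetical helper: the op built by MakeCallOp receives
// {x0, x1, capture0, ...} as its input list.
Status CallFlatTensorFunction(const FlatTensorFunction& func,
                              AbstractTensorHandle* x0,
                              AbstractTensorHandle* x1,
                              absl::Span<AbstractTensorHandle*> outputs) {
  ImmediateOpPtr call_op;
  // Only the user inputs are passed; MakeCallOp appends the captures itself.
  TF_RETURN_IF_ERROR(func.MakeCallOp({x0, x1}, &call_op));
  int num_retvals = static_cast<int>(outputs.size());
  return call_op->Execute(outputs, &num_retvals);
}
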
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h
new file mode 100644
index 0000000..e6bcdec
--- /dev/null
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h
@@ -0,0 +1,84 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
+#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/c/eager/immediate_execution_context.h"
+#include "tensorflow/c/eager/immediate_execution_operation.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
+
+namespace tensorflow {
+
+// FlatTensorFunction models a TF2 eager runtime view of a callable function,
+// taking and returning flat lists of tensors, including any captures.
+// Effectively, it is a thin wrapper around a FunctionDef owned by the
+// EagerContext, and any TensorHandle captures associated with the function.
+// The MakeCallOp method automatically handles marshaling the captures after
+// the user-provided inputs.
+// Note(bmzhao): This class is mainly intended to house low-level reusable
+// function logic between SignatureDefFunction and ConcreteFunction, which
+// present higher level interfaces. This type does *not* hold any "function
+// metadata".
+class FlatTensorFunction {
+ public:
+  // Factory for creating a FlatTensorFunction.
+  //
+  // Params:
+  //  function_def - The function_def associated with the created
+  //                 FlatTensorFunction. FlatTensorFunction will register this
+  //                 function_def with `ctx` on creation, and de-register it on
+  //                 destruction. function_def must be non-null, but
+  //                 otherwise has no lifetime requirements.
+  //  captures - The captured TensorHandles associated with this
+  //             FlatTensorFunction.
+  //  ctx      - A handle to the Tensorflow runtime. This MUST be non-null and
+  //             outlive the created FlatTensorFunction.
+  //  out      - The output FlatTensorFunction.
+  static Status Create(const FunctionDef* function_def,
+                       std::vector<ImmediateExecutionTensorHandle*> captures,
+                       ImmediateExecutionContext* ctx,
+                       std::unique_ptr<FlatTensorFunction>* out);
+
+  // This method creates a "Call" Op used to execute the function.
+  Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
+                    ImmediateOpPtr* out) const;
+
+  ~FlatTensorFunction();
+
+ private:
+  FlatTensorFunction(const std::string& name,
+                     std::vector<ImmediateExecutionTensorHandle*> captures,
+                     ImmediateExecutionContext* ctx);
+
+  FlatTensorFunction(const FlatTensorFunction&) = delete;
+  FlatTensorFunction& operator=(const FlatTensorFunction&) = delete;
+
+  // Name of the FunctionDef corresponding to this FlatTensorFunction.
+  std::string name_;
+  std::vector<ImmediateExecutionTensorHandle*> captures_;
+  ImmediateExecutionContext* ctx_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
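
A short lifecycle note on the class above, as a hedged sketch: Create registers the FunctionDef with the context and the destructor de-registers it, so `ctx` must outlive the returned object. The revive helper here is hypothetical:

// Hypothetical revive path: after this returns OK, the FunctionDef is
// registered with `ctx` and stays registered until `out` is destroyed.
Status ReviveFlatFunction(const FunctionDef* fdef,
                          std::vector<ImmediateExecutionTensorHandle*> captures,
                          ImmediateExecutionContext* ctx,
                          std::unique_ptr<FlatTensorFunction>* out) {
  return FlatTensorFunction::Create(fdef, std::move(captures), ctx, out);
}
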
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc
index f734f9e..d9773a4 100644
--- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc
@@ -22,7 +22,7 @@
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
 #include "tensorflow/c/eager/immediate_execution_operation.h"
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
-#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
+#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/platform/errors.h"
@@ -33,32 +33,20 @@
 
 namespace tensorflow {
 
-TFConcreteFunction::TFConcreteFunction(
-    const std::string& name,
-    std::vector<ImmediateExecutionTensorHandle*> captures,
-    FunctionMetadata metadata, ImmediateExecutionContext* ctx)
-    : name_(name),
-      captures_(std::move(captures)),
-      metadata_(std::move(metadata)),
-      ctx_(ctx) {}
-
-TFConcreteFunction::~TFConcreteFunction() {
-  Status status = ctx_->RemoveFunction(name_);
-  if (!status.ok()) {
-    LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
-               << status.error_message();
-  }
-}
+TFConcreteFunction::TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
+                                       FunctionMetadata metadata)
+    : func_(std::move(func)), metadata_(std::move(metadata)) {}
 
 Status TFConcreteFunction::Create(
     const FunctionDef* function_def,
     std::vector<ImmediateExecutionTensorHandle*> captures,
     FunctionMetadata metadata, ImmediateExecutionContext* ctx,
     std::unique_ptr<TFConcreteFunction>* out) {
-  TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
-  out->reset(new TFConcreteFunction(function_def->signature().name(),
-                                    std::move(captures), std::move(metadata),
-                                    ctx));
+  std::unique_ptr<FlatTensorFunction> func;
+  TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
+      function_def, std::move(captures), ctx, &func));
+
+  out->reset(new TFConcreteFunction(std::move(func), std::move(metadata)));
   return Status();
 }
 
@@ -66,30 +54,9 @@
   return metadata_;
 }
 
-Status TFConcreteFunction::GetCallOp(
-    absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) {
-  out->reset(ctx_->CreateOperation());
-  // In eager mode, TF2 python executes functions by constructing an op with
-  // the name of the functiondef:
-  // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
-  // In graph mode, we create a PartitionedCallOp instead:
-  // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
-
-  // TODO(bmzhao): After discussing with Allen, we should execute this via a
-  // PartitionedCallOp for compatibility with "tooling that assumes functions in
-  // graphs are PartitionedCallOps".
-  TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
-
-  // Adding the user-provided inputs to the function.
-  TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
-
-  absl::Span<AbstractTensorHandle* const> captures(
-      reinterpret_cast<AbstractTensorHandle**>(captures_.data()),
-      captures_.size());
-
-  // Adding the captures of the function.
-  TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
-  return Status();
+Status TFConcreteFunction::MakeCallOp(
+    absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
+  return func_->MakeCallOp(inputs, out);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h
index d38f354..edc26f4 100644
--- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h
@@ -27,7 +27,7 @@
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
 #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
 #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
-#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
+#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
 
@@ -58,26 +58,22 @@
                        std::unique_ptr<TFConcreteFunction>* out);
 
   // This method creates a "Call" Op used to execute the function.
-  Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
-                   ImmediateOpPtr* out) override;
+  Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
+                    ImmediateOpPtr* out) const override;
 
   const FunctionMetadata& GetFunctionMetadata() const override;
 
-  ~TFConcreteFunction() override;
+  ~TFConcreteFunction() override = default;
 
  private:
-  TFConcreteFunction(const std::string& name,
-                     std::vector<ImmediateExecutionTensorHandle*> captures,
-                     FunctionMetadata metadata, ImmediateExecutionContext* ctx);
+  TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
+                     FunctionMetadata metadata);
 
   TFConcreteFunction(const TFConcreteFunction&) = delete;
   TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
 
-  // Name of the FunctionDef corresponding to this TFConcreteFunction
-  std::string name_;
-  std::vector<ImmediateExecutionTensorHandle*> captures_;
+  std::unique_ptr<FlatTensorFunction> func_;
   FunctionMetadata metadata_;
-  ImmediateExecutionContext* ctx_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc
new file mode 100644
index 0000000..ab1745d
--- /dev/null
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.cc
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/immediate_execution_operation.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace tensorflow {
+
+TFSignatureDefFunction::TFSignatureDefFunction(
+    std::unique_ptr<FlatTensorFunction> func,
+    SignatureDefFunctionMetadata metadata)
+    : func_(std::move(func)), metadata_(std::move(metadata)) {}
+
+Status TFSignatureDefFunction::Create(
+    const FunctionDef* function_def,
+    std::vector<ImmediateExecutionTensorHandle*> captures,
+    SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx,
+    std::unique_ptr<TFSignatureDefFunction>* out) {
+  std::unique_ptr<FlatTensorFunction> func;
+  TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
+      function_def, std::move(captures), ctx, &func));
+
+  out->reset(new TFSignatureDefFunction(std::move(func), std::move(metadata)));
+  return Status();
+}
+
+const SignatureDefFunctionMetadata&
+TFSignatureDefFunction::GetFunctionMetadata() const {
+  return metadata_;
+}
+
+Status TFSignatureDefFunction::MakeCallOp(
+    absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
+  return func_->MakeCallOp(inputs, out);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h
new file mode 100644
index 0000000..7b56418
--- /dev/null
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h
@@ -0,0 +1,85 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_H_
+#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/c/eager/immediate_execution_context.h"
+#include "tensorflow/c/eager/immediate_execution_operation.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
+#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
+#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
+
+namespace tensorflow {
+
+// This is the TF eager runtime implementation of SignatureDefFunction (separate
+// from the TFRT implementation). The user-facing API of SignatureDefFunctions
+// and their semantic differences from ConcreteFunction are described here:
+// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59
+// Additional implementation notes are available here:
+// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48
+class TFSignatureDefFunction : public SignatureDefFunction {
+ public:
+  // Factory function for creating a TFSignatureDefFunction.
+  //
+  // Params:
+  //  function_def - The function_def associated with the created
+  //                 TFSignatureDefFunction. TFSignatureDefFunction will
+  //                 register this function_def with `ctx` on creation, and
+  //                 de-register it on destruction. function_def must be
+  //                 non-null, but otherwise has no lifetime requirements.
+  //  captures - The captured TensorHandles associated with this
+  //             TFSignatureDefFunction.
+  //  metadata - The SignatureDefFunctionMetadata for this function.
+  //  ctx      - A handle to the Tensorflow runtime. This MUST be non-null and
+  //             outlive TFSignatureDefFunction.
+  //  out      - The output TFSignatureDefFunction.
+  static Status Create(const FunctionDef* function_def,
+                       std::vector<ImmediateExecutionTensorHandle*> captures,
+                       SignatureDefFunctionMetadata metadata,
+                       ImmediateExecutionContext* ctx,
+                       std::unique_ptr<TFSignatureDefFunction>* out);
+
+  // This method creates a "Call" Op used to execute the function.
+  Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
+                    ImmediateOpPtr* out) const override;
+
+  const SignatureDefFunctionMetadata& GetFunctionMetadata() const override;
+
+  ~TFSignatureDefFunction() override = default;
+
+ private:
+  TFSignatureDefFunction(std::unique_ptr<FlatTensorFunction> func,
+                         SignatureDefFunctionMetadata metadata);
+
+  TFSignatureDefFunction(const TFSignatureDefFunction&) = delete;
+  TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete;
+
+  std::unique_ptr<FlatTensorFunction> func_;
+  SignatureDefFunctionMetadata metadata_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_H_
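
Since both revived function types delegate to FlatTensorFunction, calling a SignatureDefFunction looks the same as calling a ConcreteFunction. A hedged sketch; the helper name is hypothetical:

// Hypothetical helper: builds the call op via the virtual MakeCallOp and
// executes it, mirroring TF_ConcreteFunctionMakeCallOp in the C layer.
Status RunSignatureDefFunction(SignatureDefFunction* fn,
                               absl::Span<AbstractTensorHandle* const> inputs,
                               absl::Span<AbstractTensorHandle*> outputs) {
  ImmediateOpPtr call_op;
  TF_RETURN_IF_ERROR(fn->MakeCallOp(inputs, &call_op));
  int num_retvals = static_cast<int>(outputs.size());
  return call_op->Execute(outputs, &num_retvals);
}
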
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc
index d831a8d..a212c25 100644
--- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc
@@ -65,10 +65,11 @@
 Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
                                      DataType dtype, TensorShape shape,
                                      absl::optional<std::string> name,
+                                     const char* raw_device_name,
                                      std::unique_ptr<Variable>* output) {
   ImmediateTensorHandlePtr handle;
   TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
-      ctx, dtype, shape, &handle));
+      ctx, dtype, shape, raw_device_name, &handle));
 
   output->reset(
       new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h
index 48ea1d0..13f56fd 100644
--- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h
+++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h
@@ -37,6 +37,7 @@
   static Status CreateUninitialized(ImmediateExecutionContext* ctx,
                                     DataType dtype, TensorShape shape,
                                     absl::optional<std::string> name,
+                                    const char* raw_device_name,
                                     std::unique_ptr<Variable>* output);
 
   // The dtype of the underlying variable.
diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
index 0d97741..e79fd8d 100644
--- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
+++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc
@@ -122,9 +122,9 @@
   tensorflow::TensorShape shape(variable.shape());
   tensorflow::DataType dtype = variable.dtype();
 
-  TF_RETURN_IF_ERROR(
-      Variable::CreateUninitialized(ctx, dtype, shape, name, output));
-
+  TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
+      ctx, dtype, shape, name,
+      variable.device().empty() ? nullptr : variable.device().c_str(), output));
   return Status();
 }
 
diff --git a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc
index cf58e5e..45b0ac0 100644
--- a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc
+++ b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/c/tensor_interface.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -38,9 +39,15 @@
 class SavedVariableLoadingTest : public ::testing::TestWithParam<
                                      std::tuple<DataType, std::vector<int64>>> {
  public:
-  SavedVariableLoadingTest()
-      : device_mgr_(testing::CreateTestingDeviceMgr()),
-        ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
+  SavedVariableLoadingTest() {
+    SessionOptions options;
+    options.config.mutable_device_count()->insert({"CPU", 3});
+    std::vector<std::unique_ptr<Device>> devices;
+    TF_CHECK_OK(DeviceFactory::AddDevices(
+        options, "/job:localhost/replica:0/task:0", &devices));
+    device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
+    ctx_ = testing::CreateTestingEagerContext(device_mgr_.get());
+  }
 
   EagerContext* context() { return ctx_.get(); }
 
@@ -67,6 +74,39 @@
   EXPECT_EQ(var->shape(), shape);
 }
 
+// Verify that a device specified in the SavedVariable is kept.
+TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithDevice) {
+  auto& test_params = GetParam();
+  DataType dtype = std::get<0>(test_params);
+  TensorShape shape(std::get<1>(test_params));
+
+  SavedVariable saved_variable;
+  saved_variable.set_dtype(dtype);
+  saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
+  shape.AsProto(saved_variable.mutable_shape());
+
+  std::unique_ptr<Variable> var;
+  TF_ASSERT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
+  EXPECT_EQ(down_cast<TensorHandle*>(var->handle())->resource_device()->name(),
+            "/job:localhost/replica:0/task:0/device:CPU:1");
+}
+
+// Verify load failure if a non-existing device is specified.
+TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithInvalidDevice) {
+  auto& test_params = GetParam();
+  DataType dtype = std::get<0>(test_params);
+  TensorShape shape(std::get<1>(test_params));
+
+  SavedVariable saved_variable;
+  saved_variable.set_dtype(dtype);
+  saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:99");
+  shape.AsProto(saved_variable.mutable_shape());
+
+  std::unique_ptr<Variable> var;
+  ASSERT_NE(Status::OK(),
+            internal::LoadSavedVariable(context(), saved_variable, &var));
+}
+
 // Assigning and reading values should yield
 // consistent results.
 TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
@@ -79,7 +119,7 @@
   Status status;
   std::unique_ptr<Variable> var;
   TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
-                                             absl::nullopt, &var));
+                                             absl::nullopt, nullptr, &var));
 
   // Create a TensorHandle
   ImmediateTensorHandlePtr expected_handle =
diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc
index d551919..7c11158 100644
--- a/tensorflow/c/experimental/saved_model/core/test_utils.cc
+++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc
@@ -45,7 +45,6 @@
   return EagerContextPtr(new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
       /* async= */ false,
       /* lazy_copy_function_remote_inputs= */ false, device_mgr,
       /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc
index 65c6eca..2beed8f 100644
--- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc
+++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc
@@ -34,15 +34,15 @@
       &tensorflow::unwrap(func)->GetFunctionMetadata()));
 }
 
-TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
-                                     TFE_TensorHandle** inputs, int num_inputs,
-                                     TF_Status* status) {
+TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func,
+                                      TFE_TensorHandle** inputs, int num_inputs,
+                                      TF_Status* status) {
   tensorflow::ImmediateOpPtr call_op;
   absl::Span<tensorflow::AbstractTensorHandle* const> input_span(
       reinterpret_cast<tensorflow::AbstractTensorHandle**>(
           tensorflow::unwrap(inputs)),
       static_cast<size_t>(num_inputs));
-  status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op);
+  status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op);
   if (!status->status.ok()) {
     return nullptr;
   }
diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc
index e58b232..df998fc 100644
--- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc
+++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc
@@ -107,7 +107,7 @@
   compute_fn_inputs.push_back(input_a);
   compute_fn_inputs.push_back(input_b);
 
-  TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(
+  TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp(
       compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
   EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
 
diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h
index 0fd0f70..ff8a245 100644
--- a/tensorflow/c/experimental/saved_model/public/concrete_function.h
+++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h
@@ -47,7 +47,7 @@
 // high-level API here. A strawman for what this interface could look like:
 // TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value*
 // inputs, int num_inputs, TF_Status* status);
-TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
+TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp(
     TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
     TF_Status* status);
 
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc
index 0e55ba3..3b0fb9e 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc
@@ -63,12 +63,18 @@
   VALIDATE_MEMBER(SP_Platform, platform, name);
   VALIDATE_MEMBER(SP_Platform, platform, type);
   VALIDATE_MEMBER(SP_Platform, platform, visible_device_count);
-  VALIDATE_MEMBER(SP_Platform, platform, create_device);
-  VALIDATE_MEMBER(SP_Platform, platform, destroy_device);
-  VALIDATE_MEMBER(SP_Platform, platform, create_stream_executor);
-  VALIDATE_MEMBER(SP_Platform, platform, destroy_stream_executor);
-  VALIDATE_MEMBER(SP_Platform, platform, create_timer_fns);
-  VALIDATE_MEMBER(SP_Platform, platform, destroy_timer_fns);
+  return port::Status::OK();
+}
+
+port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
+  VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns,
+                       SP_PLATFORM_FNS_STRUCT_SIZE);
+  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device);
+  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device);
+  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_stream_executor);
+  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_stream_executor);
+  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_timer_fns);
+  VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_timer_fns);
   return port::Status::OK();
 }
 
@@ -97,11 +103,18 @@
   return port::Status::OK();
 }
 
-port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se) {
+port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
+                                      const SP_Platform& platform) {
   VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE);
   VALIDATE_MEMBER(SP_StreamExecutor, se, allocate);
   VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate);
   VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats);
+  VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_allocate);
+  VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_deallocate);
+  if (platform.supports_unified_memory) {
+    VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_allocate);
+    VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_deallocate);
+  }
   VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage);
   VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream);
   VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream);
@@ -131,9 +144,9 @@
   VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
                        SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
   VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform);
+  VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform_fns);
   return port::Status::OK();
 }
-
 #undef VALIDATE_MEMBER
 
 struct TFStatusDeleter {
@@ -297,19 +310,21 @@
 
 class CStreamExecutor : public internal::StreamExecutorInterface {
  public:
-  explicit CStreamExecutor(SP_Device device,
-                           void (*destroy_device)(SP_Device* const device),
-                           SP_StreamExecutor* stream_executor,
+  explicit CStreamExecutor(SP_Device device, SP_StreamExecutor* stream_executor,
+                           SP_Platform* platform, SP_PlatformFns* platform_fns,
                            SP_TimerFns* timer_fns, const std::string& name,
                            int visible_device_count)
       : device_(std::move(device)),
-        destroy_device_(destroy_device),
         stream_executor_(stream_executor),
+        platform_(platform),
+        platform_fns_(platform_fns),
         timer_fns_(timer_fns),
         platform_name_(name),
         visible_device_count_(visible_device_count) {}
 
-  ~CStreamExecutor() override { destroy_device_(&device_); }
+  ~CStreamExecutor() override {
+    platform_fns_->destroy_device(platform_, &device_);
+  }
 
   port::Status Init(int device_ordinal, DeviceOptions device_options) override {
     return port::Status::OK();
@@ -348,6 +363,16 @@
   bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
   bool HostMemoryUnregister(void* mem) override { return false; }
 
+  void* UnifiedMemoryAllocate(uint64 size) override {
+    CHECK(stream_executor_->unified_memory_allocate);
+    return stream_executor_->unified_memory_allocate(&device_, size);
+  }
+
+  void UnifiedMemoryDeallocate(void* mem) override {
+    CHECK(stream_executor_->unified_memory_deallocate);
+    stream_executor_->unified_memory_deallocate(&device_, mem);
+  }
+
   absl::optional<AllocatorStats> GetAllocatorStats() override {
     SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE};
     TF_Bool has_stats =
@@ -647,6 +672,7 @@
     // TODO(annarev): Figure out if we need to support more description fields.
     internal::DeviceDescriptionBuilder builder;
     builder.set_name(platform_name_);
+    // TODO(annarev): Also set `supports_unified_memory` in DeviceDescription.
     return builder.Build();
   }
 
@@ -674,8 +700,9 @@
 
  private:
   SP_Device device_;
-  void (*destroy_device_)(SP_Device* const device);
   SP_StreamExecutor* stream_executor_;
+  SP_Platform* platform_;
+  SP_PlatformFns* platform_fns_;
   SP_TimerFns* timer_fns_;
   std::string platform_name_;
   int visible_device_count_;
@@ -684,18 +711,23 @@
 
 CPlatform::CPlatform(SP_Platform platform,
                      void (*destroy_platform)(SP_Platform*),
+                     SP_PlatformFns platform_fns,
+                     void (*destroy_platform_fns)(SP_PlatformFns*),
                      SP_StreamExecutor stream_executor, SP_TimerFns timer_fns)
     : platform_(std::move(platform)),
       destroy_platform_(destroy_platform),
+      platform_fns_(std::move(platform_fns)),
+      destroy_platform_fns_(destroy_platform_fns),
       stream_executor_(std::move(stream_executor)),
       timer_fns_(std::move(timer_fns)),
       name_(platform.name) {}
 
 CPlatform::~CPlatform() {
   executor_cache_.DestroyAllExecutors();
-  platform_.destroy_stream_executor(&stream_executor_);
-  platform_.destroy_timer_fns(&timer_fns_);
+  platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
+  platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
   destroy_platform_(&platform_);
+  destroy_platform_fns_(&platform_fns_);
 }
 
 port::StatusOr<std::unique_ptr<DeviceDescription>>
@@ -735,12 +767,12 @@
   OwnedTFStatus c_status(TF_NewStatus());
 
   // Create Device
-  platform_.create_device(&device_params, c_status.get());
+  platform_fns_.create_device(&platform_, &device_params, c_status.get());
   TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
   TF_RETURN_IF_ERROR(ValidateSPDevice(device));
 
   auto executor = absl::make_unique<CStreamExecutor>(
-      std::move(device), platform_.destroy_device, &stream_executor_,
+      std::move(device), &stream_executor_, &platform_, &platform_fns_,
       &timer_fns_, name_, platform_.visible_device_count);
   auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
                                                   config.ordinal);
@@ -767,16 +799,19 @@
   SE_PlatformRegistrationParams params{
       SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
   SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
+  SP_PlatformFns platform_fns{SP_PLATFORM_FNS_STRUCT_SIZE};
   params.major_version = SE_MAJOR;
   params.minor_version = SE_MINOR;
   params.revision_version = SE_REVISION;
   params.platform = &platform;
+  params.platform_fns = &platform_fns;
 
   OwnedTFStatus c_status(TF_NewStatus());
   init_fn(&params, c_status.get());
   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
   TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
   TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
+  TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));
 
   // Fill stream executor creation params
   SE_CreateStreamExecutorParams se_params{
@@ -785,21 +820,25 @@
   se_params.stream_executor = &se;
 
   // Create StreamExecutor
-  platform.create_stream_executor(&se_params, c_status.get());
+  platform_fns.create_stream_executor(&platform, &se_params, c_status.get());
   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
-  TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se));
+  TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se, platform));
 
   SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
-  platform.create_timer_fns(&timer_fns, c_status.get());
+  platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
   TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
 
   // Register new platform
   std::string platform_name = std::string(platform.name);
   std::unique_ptr<stream_executor::CPlatform> cplatform(
-      new stream_executor::CPlatform(std::move(platform),
-                                     params.destroy_platform, std::move(se),
-                                     std::move(timer_fns)));
+      new stream_executor::CPlatform(
+          std::move(platform), params.destroy_platform, std::move(platform_fns),
+          params.destroy_platform_fns, std::move(se), std::move(timer_fns)));
   SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
       std::move(cplatform)));
 
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.h b/tensorflow/c/experimental/stream_executor/stream_executor.h
index db945df..a879f9d 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor.h
+++ b/tensorflow/c/experimental/stream_executor/stream_executor.h
@@ -43,30 +43,42 @@
 //     structs.
 //
 // Example usage:
+//
+//   /* Sample TensorFlow code below, exact implementation might differ. */
+//   // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule
+//   // above and should be set both by core and the plugin.
+//   SP_Device device { SP_DEVICE_STRUCT_SIZE };
+//   SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE };
+//   params.device = &device;
+//
+//   /* Plugin code below */
 //   constexpr char DEVICE_NAME[] = "MyDevice";
 //   constexpr char DEVICE_TYPE[] = "GPU";
 //
-//   void create_device(const SE_CreateDeviceParams* const params,
-//                      TF_Status* const status) {
-//     params->device->struct_size = SP_DEVICE_STRUCT_SIZE;
+//   void create_device(const SP_Platform* platform,
+//                      SE_CreateDeviceParams* params, TF_Status* status) {
+//     // Custom actions based on TensorFlow's view of SP_Device.
+//     OnTFDeviceView(params->device->struct_size);
+//     *params->device = { SP_DEVICE_STRUCT_SIZE };
 //     params->device->device_handle = get_my_device_handle(params->ordinal);
 //     params->device->ordinal = params->ordinal;
 //     ...
 //   }
-//   void destroy_device(SP_Device* const device) {
+//
+//   void destroy_device(const SP_Platform* platform, SP_Device* device) {
 //     delete_my_device_handle(device->device_handle);
 //   }
 //
 //   void SE_InitPlugin(
-//       SE_PlatformRegistrationParams* const params,
-//       TF_Status* const status) {
-//     params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
+//       SE_PlatformRegistrationParams* params,
+//       TF_Status* status) {
+//     *params->platform = { SP_PLATFORM_STRUCT_SIZE };
 //     // Values such as `name` and `type` must outlive SE_InitPlugin call.
 //     params->platform->name = DEVICE_NAME;
 //     params->platform->type = DEVICE_TYPE;
 //     params->platform->visible_device_count = 2;
-//     params->platform->create_device = create_device;
-//     params->platform->destroy_device = destroy_device;
+//     params->platform_fns->create_device = create_device;
+//     params->platform_fns->destroy_device = destroy_device;
 //     ...
 //   }
 
@@ -155,7 +167,8 @@
   void* ext;        // reserved for future use
   int32_t ordinal;  // device index
 
-  SP_Device* device;  // output, to be filled by plugin
+  SP_Device* device;  // Input/output, struct_size set by TF for plugin to read.
+                      // Subsequently plugin fills the entire struct.
 } SE_CreateDeviceParams;
 
 #define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \
@@ -186,6 +199,17 @@
   // Deallocates a region of host memory allocated by `host_memory_allocate`.
   void (*host_memory_deallocate)(const SP_Device* device, void* mem);
 
+  // Allocates unified memory space of the given size, if supported. Unified
+  // memory support is indicated by setting the `supports_unified_memory`
+  // field in `SP_Platform`.
+  void* (*unified_memory_allocate)(const SP_Device* device, uint64_t bytes);
+
+  // Deallocates unified memory space previously allocated with
+  // `unified_memory_allocate`. Only required when `supports_unified_memory`
+  // is set in `SP_Platform`.
+  void (*unified_memory_deallocate)(const SP_Device* device, void* location);
+
   // Fills SP_AllocatorStats with allocator statistics, if it is available.
   // If it is not available, return false.
   TF_Bool (*get_allocator_stats)(const SP_Device* device,
@@ -334,27 +358,46 @@
   // Number of visible devices
   size_t visible_device_count;
 
+  // Whether this platform supports unified memory.
+  // Unified memory is a single memory address space accessible from any device.
+  TF_Bool supports_unified_memory;
+} SP_Platform;
+
+#define SP_PLATFORM_STRUCT_SIZE \
+  TF_OFFSET_OF_END(SP_Platform, supports_unified_memory)
+
+typedef struct SP_PlatformFns {
+  size_t struct_size;
+
+  void* ext;  // reserved for future use
+
   // Callbacks for creating/destroying SP_Device.
-  void (*create_device)(const SE_CreateDeviceParams* params, TF_Status* status);
+  void (*create_device)(const SP_Platform* platform,
+                        SE_CreateDeviceParams* params, TF_Status* status);
 
   // Clean up fields inside SP_Device that were allocated
   // by the plugin. `device` itself should not be deleted here.
-  void (*destroy_device)(SP_Device* device);
+  void (*destroy_device)(const SP_Platform* platform, SP_Device* device);
 
   // Callbacks for creating/destroying SP_StreamExecutor.
-  void (*create_stream_executor)(const SE_CreateStreamExecutorParams* params,
+  void (*create_stream_executor)(const SP_Platform* platform,
+                                 SE_CreateStreamExecutorParams* params,
                                  TF_Status* status);
   // Clean up fields inside SP_StreamExecutor that were allocated
   // by the plugin. `stream_executor` itself should not be deleted here.
-  void (*destroy_stream_executor)(SP_StreamExecutor* stream_executor);
+  void (*destroy_stream_executor)(const SP_Platform* platform,
+                                  SP_StreamExecutor* stream_executor);
 
   // Callbacks for creating/destroying SP_TimerFns.
-  void (*create_timer_fns)(SP_TimerFns* timer, TF_Status* status);
+  void (*create_timer_fns)(const SP_Platform* platform, SP_TimerFns* timer,
+                           TF_Status* status);
 
-  void (*destroy_timer_fns)(SP_TimerFns* timer_fns);
-} SP_Platform;
+  void (*destroy_timer_fns)(const SP_Platform* platform,
+                            SP_TimerFns* timer_fns);
+} SP_PlatformFns;
 
-#define SP_PLATFORM_STRUCT_SIZE TF_OFFSET_OF_END(SP_Platform, destroy_timer_fns)
+#define SP_PLATFORM_FNS_STRUCT_SIZE \
+  TF_OFFSET_OF_END(SP_PlatformFns, destroy_timer_fns)
 
 typedef struct SE_PlatformRegistrationParams {
   size_t struct_size;
@@ -365,14 +408,17 @@
   int32_t minor_version;
   int32_t revision_version;
 
-  SP_Platform* platform;  // output, set by plugin
+  SP_Platform* platform;         // output, set by plugin
+  SP_PlatformFns* platform_fns;  // output, set by plugin
   // Clean up fields inside SP_Platform that were allocated
   // by the plugin. `platform` itself should not be deleted here.
   void (*destroy_platform)(SP_Platform* platform);  // out, set by plugin
+  void (*destroy_platform_fns)(
+      SP_PlatformFns* platform_fns);  // out, set by plugin
 } SE_PlatformRegistrationParams;
 
 #define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \
-  TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform)
+  TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform_fns)
 
 void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status);
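
With the platform split into SP_Platform (data) and SP_PlatformFns (callbacks),
a plugin's SE_InitPlugin now fills both structs and registers both destroy
hooks. A minimal sketch, assuming my_*-prefixed callbacks with the new
signatures are defined elsewhere (all my_* names and the device name/type
values are hypothetical):

void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  // Data struct.
  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  params->platform->name = "MY_DEVICE";
  params->platform->type = "GPU";
  params->platform->visible_device_count = 1;
  params->platform->supports_unified_memory = false;
  // Callback struct; creation/destruction hooks now live here.
  params->platform_fns->struct_size = SP_PLATFORM_FNS_STRUCT_SIZE;
  params->platform_fns->create_device = my_create_device;
  params->platform_fns->destroy_device = my_destroy_device;
  params->platform_fns->create_stream_executor = my_create_stream_executor;
  params->platform_fns->destroy_stream_executor = my_destroy_stream_executor;
  params->platform_fns->create_timer_fns = my_create_timer_fns;
  params->platform_fns->destroy_timer_fns = my_destroy_timer_fns;
  // Cleanup for the two structs themselves.
  params->destroy_platform = my_destroy_platform;
  params->destroy_platform_fns = my_destroy_platform_fns;
}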
 
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
index 2285fe8..557c556 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h
@@ -40,6 +40,8 @@
  public:
   explicit CPlatform(SP_Platform platform,
                      void (*destroy_platform)(SP_Platform*),
+                     SP_PlatformFns platform_fns,
+                     void (*destroy_platform_fns)(SP_PlatformFns*),
                      SP_StreamExecutor stream_executor, SP_TimerFns timer_fns);
   ~CPlatform() override;
 
@@ -69,6 +71,8 @@
  private:
   SP_Platform platform_;
   void (*destroy_platform_)(SP_Platform*);
+  SP_PlatformFns platform_fns_;
+  void (*destroy_platform_fns_)(SP_PlatformFns*);
   SP_StreamExecutor stream_executor_;
   SP_TimerFns timer_fns_;
   const std::string name_;
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
index 5eddeff..b8efa2d 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
@@ -50,6 +50,10 @@
               int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
 void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
 }
+void* host_memory_allocate(const SP_Device* const device, uint64_t size) {
+  return nullptr;
+}
+void host_memory_deallocate(const SP_Device* const device, void* mem) {}
 TF_Bool get_allocator_stats(const SP_Device* const device,
                             SP_AllocatorStats* const stats) {
   return true;
@@ -114,6 +118,8 @@
   se->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE;
   se->allocate = allocate;
   se->deallocate = deallocate;
+  se->host_memory_allocate = host_memory_allocate;
+  se->host_memory_deallocate = host_memory_deallocate;
   se->get_allocator_stats = get_allocator_stats;
   se->device_memory_usage = device_memory_usage;
   se->create_stream = create_stream;
@@ -146,48 +152,54 @@
 }
 
 /*** Create SP_Platform ***/
-void create_timer_fns(SP_TimerFns* const timer_fns, TF_Status* const status) {
+void create_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns,
+                      TF_Status* status) {
   TF_SetStatus(status, TF_OK, "");
   PopulateDefaultTimerFns(timer_fns);
 }
-void destroy_timer_fns(SP_TimerFns* const timer_fns) {}
+void destroy_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
 
-void create_stream_executor(const SE_CreateStreamExecutorParams* const params,
-                            TF_Status* const status) {
+void create_stream_executor(const SP_Platform* platform,
+                            SE_CreateStreamExecutorParams* params,
+                            TF_Status* status) {
   TF_SetStatus(status, TF_OK, "");
   PopulateDefaultStreamExecutor(params->stream_executor);
 }
-void destroy_stream_executor(SP_StreamExecutor* const se) {}
+void destroy_stream_executor(const SP_Platform* platform,
+                             SP_StreamExecutor* se) {}
 
-void create_device(const SE_CreateDeviceParams* const params,
-                   TF_Status* const status) {
+void create_device(const SP_Platform* platform, SE_CreateDeviceParams* params,
+                   TF_Status* status) {
   TF_SetStatus(status, TF_OK, "");
   params->device->struct_size = SP_DEVICE_STRUCT_SIZE;
 }
-void destroy_device(SP_Device* const device) {}
+void destroy_device(const SP_Platform* platform, SP_Device* device) {}
 
-void PopulateDefaultPlatform(SP_Platform* platform) {
+void PopulateDefaultPlatform(SP_Platform* platform,
+                             SP_PlatformFns* platform_fns) {
   platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
   platform->name = DEVICE_NAME;
   platform->type = DEVICE_TYPE;
   platform->visible_device_count = DEVICE_COUNT;
-  platform->create_device = create_device;
-  platform->destroy_device = destroy_device;
-  platform->create_stream_executor = create_stream_executor;
-  platform->destroy_stream_executor = destroy_stream_executor;
-  platform->create_timer_fns = create_timer_fns;
-  platform->destroy_timer_fns = destroy_timer_fns;
+  platform_fns->create_device = create_device;
+  platform_fns->destroy_device = destroy_device;
+  platform_fns->create_stream_executor = create_stream_executor;
+  platform_fns->destroy_stream_executor = destroy_stream_executor;
+  platform_fns->create_timer_fns = create_timer_fns;
+  platform_fns->destroy_timer_fns = destroy_timer_fns;
 }
 
 void destroy_platform(SP_Platform* const platform) {}
+void destroy_platform_fns(SP_PlatformFns* const platform_fns) {}
 
 /*** Registration tests ***/
 TEST(StreamExecutor, SuccessfulRegistration) {
   auto plugin_init = [](SE_PlatformRegistrationParams* const params,
                         TF_Status* const status) -> void {
     TF_SetStatus(status, TF_OK, "");
-    PopulateDefaultPlatform(params->platform);
+    PopulateDefaultPlatform(params->platform, params->platform_fns);
     params->destroy_platform = destroy_platform;
+    params->destroy_platform_fns = destroy_platform_fns;
   };
   port::Status status = RegisterDevicePlugin(plugin_init);
   TF_ASSERT_OK(status);
@@ -209,9 +221,10 @@
   auto plugin_init = [](SE_PlatformRegistrationParams* const params,
                         TF_Status* const status) -> void {
     TF_SetStatus(status, TF_OK, "");
-    PopulateDefaultPlatform(params->platform);
+    PopulateDefaultPlatform(params->platform, params->platform_fns);
     params->platform->name = nullptr;
     params->destroy_platform = destroy_platform;
+    params->destroy_platform_fns = destroy_platform_fns;
   };
 
   port::Status status = RegisterDevicePlugin(plugin_init);
@@ -223,15 +236,33 @@
   auto plugin_init = [](SE_PlatformRegistrationParams* const params,
                         TF_Status* const status) -> void {
     TF_SetStatus(status, TF_OK, "");
-    PopulateDefaultPlatform(params->platform);
-    params->platform->create_device = nullptr;
+    PopulateDefaultPlatform(params->platform, params->platform_fns);
+    params->platform_fns->create_device = nullptr;
     params->destroy_platform = destroy_platform;
+    params->destroy_platform_fns = destroy_platform_fns;
   };
 
   port::Status status = RegisterDevicePlugin(plugin_init);
   ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
   ASSERT_EQ(status.error_message(),
-            "'create_device' field in SP_Platform must be set.");
+            "'create_device' field in SP_PlatformFns must be set.");
+}
+
+TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) {
+  auto plugin_init = [](SE_PlatformRegistrationParams* const params,
+                        TF_Status* const status) -> void {
+    TF_SetStatus(status, TF_OK, "");
+    PopulateDefaultPlatform(params->platform, params->platform_fns);
+    params->platform->supports_unified_memory = true;
+    params->destroy_platform = destroy_platform;
+    params->destroy_platform_fns = destroy_platform_fns;
+  };
+
+  port::Status status = RegisterDevicePlugin(plugin_init);
+  ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
+  ASSERT_EQ(
+      status.error_message(),
+      "'unified_memory_allocate' field in SP_StreamExecutor must be set.");
 }
 
 /*** StreamExecutor behavior tests ***/
@@ -239,7 +270,7 @@
  protected:
   StreamExecutorTest() {}
   void SetUp() override {
-    PopulateDefaultPlatform(&platform_);
+    PopulateDefaultPlatform(&platform_, &platform_fns_);
     PopulateDefaultStreamExecutor(&se_);
     PopulateDefaultTimerFns(&timer_fns_);
   }
@@ -247,8 +278,9 @@
 
   StreamExecutor* GetExecutor(int ordinal) {
     if (!cplatform_) {
-      cplatform_ = absl::make_unique<CPlatform>(platform_, destroy_platform,
-                                                se_, timer_fns_);
+      cplatform_ = absl::make_unique<CPlatform>(
+          platform_, destroy_platform, platform_fns_, destroy_platform_fns, se_,
+          timer_fns_);
     }
     port::StatusOr<StreamExecutor*> maybe_executor =
         cplatform_->ExecutorForDevice(ordinal);
@@ -256,6 +288,7 @@
     return maybe_executor.ConsumeValueOrDie();
   }
   SP_Platform platform_;
+  SP_PlatformFns platform_fns_;
   SP_StreamExecutor se_;
   SP_TimerFns timer_fns_;
   std::unique_ptr<CPlatform> cplatform_;
@@ -265,13 +298,13 @@
   se_.allocate = [](const SP_Device* const device, uint64_t size,
                     int64_t memory_space, SP_DeviceMemoryBase* const mem) {
     mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE;
-    mem->opaque = std::malloc(size);
+    mem->opaque = malloc(size);
     mem->size = size;
   };
   se_.deallocate = [](const SP_Device* const device,
                       SP_DeviceMemoryBase* const mem) {
     EXPECT_EQ(mem->size, 2 * sizeof(int));
-    std::free(mem->opaque);
+    free(mem->opaque);
     mem->opaque = nullptr;
     mem->size = 0;
   };
@@ -288,10 +321,10 @@
   static bool deallocate_called = false;
   se_.host_memory_allocate = [](const SP_Device* const device, uint64_t size) {
     allocate_called = true;
-    return std::malloc(size);
+    return malloc(size);
   };
   se_.host_memory_deallocate = [](const SP_Device* const device, void* mem) {
-    std::free(mem);
+    free(mem);
     deallocate_called = true;
   };
   StreamExecutor* executor = GetExecutor(0);
@@ -304,6 +337,28 @@
   ASSERT_TRUE(deallocate_called);
 }
 
+TEST_F(StreamExecutorTest, UnifiedMemoryAllocate) {
+  static bool allocate_called = false;
+  static bool deallocate_called = false;
+  se_.unified_memory_allocate = [](const SP_Device* const device,
+                                   uint64_t size) {
+    allocate_called = true;
+    return malloc(size);
+  };
+  se_.unified_memory_deallocate = [](const SP_Device* const device, void* mem) {
+    free(mem);
+    deallocate_called = true;
+  };
+  StreamExecutor* executor = GetExecutor(0);
+  ASSERT_FALSE(allocate_called);
+  void* mem = executor->UnifiedMemoryAllocate(8);
+  ASSERT_NE(mem, nullptr);
+  ASSERT_TRUE(allocate_called);
+  ASSERT_FALSE(deallocate_called);
+  executor->UnifiedMemoryDeallocate(mem);
+  ASSERT_TRUE(deallocate_called);
+}
+
 TEST_F(StreamExecutorTest, GetAllocatorStats) {
   se_.get_allocator_stats = [](const SP_Device* const device,
                                SP_AllocatorStats* const stat) -> TF_Bool {
diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc
index 0b12b17..ed501b5 100644
--- a/tensorflow/c/kernels.cc
+++ b/tensorflow/c/kernels.cc
@@ -280,6 +280,36 @@
   return tf_tensor;
 }
 
+TF_Tensor* TF_ForwardInputOrAllocateOutput(
+    TF_OpKernelContext* context, int* candidate_input_indices,
+    int num_candidate_input_indices, int output_index, int64_t* output_dims,
+    int output_num_dims, int* forwarded_input, TF_Status* status) {
+  TF_SetStatus(status, TF_OK, "");
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
+
+  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+                "64-bit int types should match in size");
+  tensorflow::gtl::ArraySlice<int> input_indices_array(
+      candidate_input_indices, num_candidate_input_indices);
+  tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
+      reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
+  tensorflow::Tensor* output_tensor_pointer;
+  tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
+      input_indices_array, output_index,
+      tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
+      forwarded_input);
+  if (!s.ok()) {
+    ::tensorflow::Set_TF_Status_from_Status(status, s);
+    return nullptr;
+  }
+  TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
+  if (!s.ok()) {
+    ::tensorflow::Set_TF_Status_from_Status(status, s);
+    return nullptr;
+  }
+  return tf_tensor_output;
+}
+
 TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
                            int64_t* dims, int num_dims,
                            TF_AllocatorAttributes* attributes,
diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h
index 15fcf0f..489aa53 100644
--- a/tensorflow/c/kernels.h
+++ b/tensorflow/c/kernels.h
@@ -200,6 +200,17 @@
                                             int64_t* dims, int num_dims,
                                             size_t len, TF_Status* status);
 
+// Tries to forward one of the inputs given in `candidate_input_indices` to
+// output[output_index]. If none of the given inputs can be forwarded, calls
+// allocate_output() to allocate a new output buffer. The index of the
+// forwarded input will be assigned to the output argument `forwarded_input`
+// (if it's not nullptr). If no input is forwarded, `forwarded_input` will be
+// assigned -1.
+TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
+    TF_OpKernelContext* context, int* candidate_input_indices,
+    int num_candidate_input_indices, int output_index, int64_t* output_dims,
+    int output_num_dims, int* forwarded_input, TF_Status* status);
+
 // Allocates a temporary Tensor of the specified type and shape. The
 // Tensor must not be used after kernel construction is
 // complete.
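
As a usage sketch for TF_ForwardInputOrAllocateOutput above, mirroring the
kernel test added to kernels_test.cc in this change (MyKernel_Compute is a
hypothetical Compute callback for a kernel whose input 0 and output 0 share a
scalar shape and dtype):

void MyKernel_Compute(void* kernel, TF_OpKernelContext* ctx) {
  TF_Status* s = TF_NewStatus();
  int candidate_input_indices[1] = {0};
  int forwarded_input = -1;
  int64_t output_dims[1] = {};  // ignored for a 0-d (scalar) output
  TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
      ctx, candidate_input_indices, /*num_candidate_input_indices=*/1,
      /*output_index=*/0, output_dims, /*output_num_dims=*/0,
      &forwarded_input, s);
  if (TF_GetCode(s) == TF_OK) {
    // forwarded_input is 0 if input 0 was reused and -1 if a fresh buffer
    // was allocated; either way `output` refers to output 0.
    TF_DeleteTensor(output);
  }
  TF_DeleteStatus(s);
}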
diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD
index c6cc3f8..6bb2b347 100644
--- a/tensorflow/c/kernels/BUILD
+++ b/tensorflow/c/kernels/BUILD
@@ -53,6 +53,19 @@
     ],
 )
 
+tf_kernel_library(
+    name = "merge_summary_op",
+    prefix = "merge_summary_op",
+    deps = [
+        "//tensorflow/c:kernels",
+        "//tensorflow/c:tf_status",
+        "//tensorflow/c:tf_tensor",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
 tf_gen_op_libs(
     op_lib_names = ["bitcast"],
     deps = [
@@ -82,6 +95,15 @@
     ],
 )
 
+tf_gen_op_libs(
+    op_lib_names = ["merge_summary"],
+    deps = [
+        "//tensorflow/c:ops",
+        "//tensorflow/c:tf_status",
+        "//tensorflow/core:lib",
+    ],
+)
+
 tf_cc_test(
     name = "bitcast_op_test",
     srcs = ["bitcast_op_test.cc"],
@@ -110,6 +132,23 @@
     ],
 )
 
+tf_cc_test(
+    name = "summary_op_benchmark_test",
+    size = "small",
+    srcs = ["summary_op_benchmark_test.cc"],
+    deps = [
+        ":summary_op",
+        "//tensorflow/c:kernels",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 cc_library(
     name = "tensor_shape_utils",
     srcs = ["tensor_shape_utils.cc"],
@@ -146,6 +185,7 @@
     srcs = [
         "bitcast_op.cc",
         "histogram_summary_op.cc",
+        "merge_summary_op.cc",
         "summary_op.cc",
         "tensor_shape_utils.cc",
         "tensor_shape_utils.h",
@@ -158,6 +198,7 @@
     srcs = [
         "ops/bitcast.cc",
         "ops/histogram_summary.cc",
+        "ops/merge_summary.cc",
         "ops/summary.cc",
     ],
 )
diff --git a/tensorflow/c/kernels/histogram_summary_op.cc b/tensorflow/c/kernels/histogram_summary_op.cc
index 5de5270..143a267 100644
--- a/tensorflow/c/kernels/histogram_summary_op.cc
+++ b/tensorflow/c/kernels/histogram_summary_op.cc
@@ -93,11 +93,13 @@
       std::ostringstream err;
       err << "Nan in summary histogram for: " << k->op_node_name;
       TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
+      TF_OpKernelContext_Failure(ctx, status.get());
       return;
     } else if (Eigen::numext::isinf(double_val)) {
       std::ostringstream err;
       err << "Infinity in Histogram for: " << k->op_node_name;
       TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
+      TF_OpKernelContext_Failure(ctx, status.get());
       return;
     }
     histo.Add(double_val);
diff --git a/tensorflow/c/kernels/merge_summary_op.cc b/tensorflow/c/kernels/merge_summary_op.cc
new file mode 100644
index 0000000..e450293
--- /dev/null
+++ b/tensorflow/c/kernels/merge_summary_op.cc
@@ -0,0 +1,123 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <sstream>
+#include <unordered_set>
+
+#include "tensorflow/c/kernels.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/core/framework/selective_registration.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/tstring.h"
+
+namespace {
+
+// Deleters used to construct std::unique_ptr wrappers for TF_Tensor and
+// TF_Status
+struct TFTensorDeleter {
+  void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
+};
+
+struct TFStatusDeleter {
+  void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
+};
+
+// unique_ptr aliases that delete the wrapped TF_Tensor/TF_Status once they
+// go out of scope
+using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
+using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
+
+// Dummy functions used for kernel registration
+void* MergeSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
+
+void MergeSummaryOp_Delete(void* kernel) {}
+
+void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
+  tensorflow::Summary s;
+  std::unordered_set<tensorflow::string> tags;
+  Safe_TF_StatusPtr status(TF_NewStatus());
+  for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) {
+    TF_Tensor* input;
+    TF_GetInput(ctx, input_num, &input, status.get());
+    Safe_TF_TensorPtr safe_input_ptr(input);
+    if (TF_GetCode(status.get()) != TF_OK) {
+      TF_OpKernelContext_Failure(ctx, status.get());
+      return;
+    }
+    auto summary_array =
+        static_cast<tensorflow::tstring*>(TF_TensorData(safe_input_ptr.get()));
+    for (int i = 0; i < TF_TensorElementCount(safe_input_ptr.get()); ++i) {
+      const tensorflow::tstring& s_in = summary_array[i];
+      tensorflow::Summary summary_in;
+      if (!tensorflow::ParseProtoUnlimited(&summary_in, s_in)) {
+        TF_SetStatus(status.get(), TF_INVALID_ARGUMENT,
+                     "Could not parse one of the summary inputs");
+        TF_OpKernelContext_Failure(ctx, status.get());
+        return;
+      }
+      for (int v = 0; v < summary_in.value_size(); ++v) {
+        // Values coming from TensorSummary ops have empty tags, which are
+        // exempt from the duplicate-tag check below.
+        const tensorflow::string& tag = summary_in.value(v).tag();
+        if ((!tag.empty()) && !tags.insert(tag).second) {
+          std::ostringstream err;
+          err << "Duplicate tag " << tag << " found in summary inputs ";
+          TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
+          TF_OpKernelContext_Failure(ctx, status.get());
+          return;
+        }
+        *s.add_value() = summary_in.value(v);
+      }
+    }
+  }
+  Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
+      /*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
+      /*dims=*/nullptr, /*num_dims=*/0,
+      /*len=*/sizeof(tensorflow::tstring), status.get()));
+  if (TF_GetCode(status.get()) != TF_OK) {
+    TF_OpKernelContext_Failure(ctx, status.get());
+    return;
+  }
+  tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
+      TF_TensorData(summary_tensor.get()));
+  CHECK(SerializeToTString(s, output_tstring));
+}
+
+void RegisterMergeSummaryOpKernel() {
+  TF_Status* status = TF_NewStatus();
+  {
+    auto* builder = TF_NewKernelBuilder(
+        "MergeSummary", tensorflow::DEVICE_CPU, &MergeSummaryOp_Create,
+        &MergeSummaryOp_Compute, &MergeSummaryOp_Delete);
+    TF_RegisterKernelBuilder("MergeSummary", builder, status);
+    CHECK_EQ(TF_OK, TF_GetCode(status))
+        << "Error while registering Merge Summmary kernel";
+  }
+  TF_DeleteStatus(status);
+}
+
+// A dummy static variable initialized by a lambda whose side-effect is to
+// register the Merge Summary kernel.
+TF_ATTRIBUTE_UNUSED static bool IsMergeSummaryOpKernelRegistered = []() {
+  if (SHOULD_REGISTER_OP_KERNEL("MergeSummary")) {
+    RegisterMergeSummaryOpKernel();
+  }
+  return true;
+}();
+
+}  // namespace
diff --git a/tensorflow/c/kernels/ops/merge_summary.cc b/tensorflow/c/kernels/ops/merge_summary.cc
new file mode 100644
index 0000000..991c469
--- /dev/null
+++ b/tensorflow/c/kernels/ops/merge_summary.cc
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/ops.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/framework/selective_registration.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+static void merge_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
+                                             TF_Status* status) {
+  TF_SetStatus(status, TF_OK, "");
+  TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
+  TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
+  TF_DeleteShapeHandle(result);
+}
+
+void Register_MergeSummaryOp() {
+  TF_Status* status = TF_NewStatus();
+
+  TF_OpDefinitionBuilder* op_builder =
+      TF_NewOpDefinitionBuilder("MergeSummary");
+  TF_OpDefinitionBuilderAddInput(op_builder, "inputs: N * string");
+  TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
+  TF_OpDefinitionBuilderAddAttr(op_builder, "N: int >= 1");
+  TF_OpDefinitionBuilderSetShapeInferenceFunction(
+      op_builder, &merge_summary_shape_inference_fn);
+
+  TF_RegisterOpDefinition(op_builder, status);
+  CHECK_EQ(TF_GetCode(status), TF_OK)
+      << "MergeSummary op registration failed: " << TF_Message(status);
+  TF_DeleteStatus(status);
+}
+
+TF_ATTRIBUTE_UNUSED static bool MergeSummaryOpRegistered = []() {
+  if (SHOULD_REGISTER_OP("MergeSummary")) {
+    Register_MergeSummaryOp();
+  }
+  return true;
+}();
diff --git a/tensorflow/c/kernels/summary_op_benchmark_test.cc b/tensorflow/c/kernels/summary_op_benchmark_test.cc
new file mode 100644
index 0000000..887a860
--- /dev/null
+++ b/tensorflow/c/kernels/summary_op_benchmark_test.cc
@@ -0,0 +1,71 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor tags(DT_STRING, shape);
+  Tensor values(DT_FLOAT, shape);
+  for (int i = 0; i < tags.NumElements(); ++i) {
+    tags.flat<tstring>()(i) = tag;
+    values.flat<float>()(i) = value;
+  }
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "ScalarSummary")
+                  .Input(test::graph::Constant(g, tags))
+                  .Input(test::graph::Constant(g, values))
+                  .Attr("T", DT_FLOAT)
+                  .Finalize(g, &ret));
+  return g;
+}
+
+// Macro used to pass an initializer list to the TensorShape constructor
+#define DIMARGS(...) \
+  { __VA_ARGS__ }
+// Random parameters for testing
+constexpr char longTagParam[] = "LONGTAG____________________________";
+constexpr float largeValueParam = 2352352.2623433;
+
+#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
+  void BM_ScalarSummary##name##device(int iters) {          \
+    testing::StopTiming();                                  \
+    TensorShape tensorshape(DIMARGS dims);                  \
+    auto g = BM_ScalarSummaryOp(tensorshape, #tag, value);  \
+    testing::StartTiming();                                 \
+    test::Benchmark("cpu", g).Run(iters);                   \
+  }                                                         \
+  BENCHMARK(BM_ScalarSummary##name##device);
+
+BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2);
+// Benchmark for large shapes
+BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeShape, Tag, 5.2);
+// Benchmark for large tag tstring
+BM_ScalarSummaryDev(Cpu, (5, 10, 100), LongTag, longTagParam, 5.2);
+// Benchmark for large values
+BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeValue, Tag, largeValueParam);
+
+}  // namespace
+}  // namespace tensorflow
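
For reference, the first invocation above expands to roughly the following
(registered with the benchmark framework via BENCHMARK), since DIMARGS turns
the parenthesized dimensions into an initializer list and #tag stringizes the
tag argument:

void BM_ScalarSummaryBaseCpu(int iters) {
  testing::StopTiming();
  TensorShape tensorshape({5, 10, 100});
  auto g = BM_ScalarSummaryOp(tensorshape, "Tag", 5.2);
  testing::StartTiming();
  test::Benchmark("cpu", g).Run(iters);
}
BENCHMARK(BM_ScalarSummaryBaseCpu);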
diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc
index e8223e4..c9df2cc 100644
--- a/tensorflow/c/kernels_test.cc
+++ b/tensorflow/c/kernels_test.cc
@@ -565,6 +565,74 @@
             output->DebugString(100));
 }
 
+TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
+  const char* node_name = "TestForwardInputOrAllocateOutputKernel";
+  const char* op_name = "BazOp";
+  const char* device_name = "FakeDeviceName";
+
+  REGISTER_OP(op_name)
+      .Input("input1: float")
+      .Input("input2: float")
+      .Output("output1: float")
+      .Attr("SomeDataTypeAttr: type");
+
+  // A kernel whose Compute function forwards a scalar input to the output
+  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
+    TF_Status* s = TF_NewStatus();
+    int candidate_input_indices[1] = {0};
+    int forwarded_input;
+    int64_t output_dims[1] = {};
+    TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
+        /*context=*/ctx, candidate_input_indices,
+        /*num_candidate_input_indices=*/1,
+        /*output_index=*/0, output_dims, /*output_num_dims=*/0,
+        &forwarded_input, /*status=*/s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s));
+    EXPECT_EQ(forwarded_input, 0);
+    EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
+    EXPECT_EQ(0, TF_NumDims(output));
+    TF_DeleteStatus(s);
+    TF_DeleteTensor(output);
+  };
+
+  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
+                                                  my_compute_func, nullptr);
+
+  {
+    TF_Status* status = TF_NewStatus();
+    TF_RegisterKernelBuilder(node_name, builder, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status));
+    TF_DeleteStatus(status);
+  }
+
+  {
+    OpKernelContext::Params p;
+    DummyDevice dummy_device(nullptr);
+    p.device = &dummy_device;
+    AllocatorAttributes alloc_attrs;
+    p.output_attr_array = &alloc_attrs;
+
+    Tensor t(123.0f);
+
+    gtl::InlinedVector<TensorValue, 4> inputs;
+    // GetFakeKernel requires a NodeDef with two inputs
+    inputs.emplace_back(&t);
+    inputs.emplace_back();
+    p.inputs = &inputs;
+
+    Status status;
+    std::unique_ptr<OpKernel> kernel =
+        GetFakeKernel(device_name, op_name, node_name, &status);
+    TF_EXPECT_OK(status);
+    ASSERT_NE(nullptr, kernel.get());
+
+    p.op_kernel = kernel.get();
+    OpKernelContext ctx(&p);
+    kernel->Compute(&ctx);
+    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
+  }
+}
+
 void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
                      TF_DataType dtype) {
   EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
diff --git a/tensorflow/c/logging.cc b/tensorflow/c/logging.cc
index bf6bf06..13c9e6a 100644
--- a/tensorflow/c/logging.cc
+++ b/tensorflow/c/logging.cc
@@ -28,6 +28,7 @@
   va_list args;
   va_start(args, fmt);
   auto message = BuildMessage(fmt, args);
+  va_end(args);
   switch (level) {
     case TF_INFO:
       LOG(INFO) << message;
@@ -48,6 +49,7 @@
   va_list args;
   va_start(args, fmt);
   auto message = BuildMessage(fmt, args);
+  va_end(args);
   VLOG(level) << message;
 }
 
@@ -55,5 +57,6 @@
   va_list args;
   va_start(args, fmt);
   auto message = BuildMessage(fmt, args);
+  va_end(args);
   DVLOG(level) << message;
 }
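
All three fixes apply the same varargs rule: every va_start must be paired
with a va_end once the va_list has been consumed, and omitting va_end is
undefined behavior on some ABIs. A generic sketch of the corrected pattern,
independent of this file:

#include <cstdarg>
#include <cstdio>
#include <string>

// Formats like printf; note the va_end immediately after the last use of
// `args`, matching the va_start above it.
std::string FormatMessage(const char* fmt, ...) {
  va_list args;
  va_start(args, fmt);
  char buf[256];
  vsnprintf(buf, sizeof(buf), fmt, args);
  va_end(args);
  return std::string(buf);
}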
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 35c6a8b..5992b45 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -329,6 +329,7 @@
     srcs = ["xla_compilation_cache.cc"],
     hdrs = ["xla_compilation_cache.h"],
     deps = [
+        ":flags",
         ":xla_activity_listener",
         ":xla_activity_proto_cc",
         "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
@@ -361,8 +362,11 @@
         "xla_compilation_cache_test.cc",
     ],
     deps = [
+        ":flags",
         ":xla_compilation_cache",
+        ":xla_cpu_jit",
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ],
@@ -918,6 +922,7 @@
         ":xla_cpu_jit",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:functional_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/cc:scope",
         "//tensorflow/compiler/tf2xla:test_util",
diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index 6d4bc51..5575320 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -518,10 +518,15 @@
   }
 }
 
+// Returns `true` iff the node has the given `attr` set to `true`. Returns
+// `false` both when the attr is missing and when it is set to `false`.
+static bool HasBoolAttr(const NodeDef& node, const char* attr) {
+  const auto& it = node.attr().find(attr);
+  return it != node.attr().end() && it->second.b();
+}
+
 bool CanCreateXlaKernel(const NodeDef& node_def) {
-  // If kXlaMustCompileAttr is set on the node_def, use its value.
-  const auto& it = node_def.attr().find(kXlaMustCompileAttr);
-  return it != node_def.attr().end() && it->second.b();
+  return HasBoolAttr(node_def, kXlaMustCompileAttr);
 }
 
 Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
@@ -564,4 +569,58 @@
   return Status::OK();
 }
 
+static auto const ops_triggering_xla_compilation =
+    new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
+                                         "XlaConv",
+                                         "XlaDequantize",
+                                         "XlaDot",
+                                         "XlaDynamicSlice",
+                                         "XlaDynamicUpdateSlice",
+                                         "XlaEinsum",
+                                         "XlaGather",
+                                         "XlaIf",
+                                         "XlaKeyValueSort",
+                                         "XlaPad",
+                                         "XlaRecv",
+                                         "XlaReduce",
+                                         "XlaReduceWindow",
+                                         "XlaReplicaId",
+                                         "XlaScatter",
+                                         "XlaSelectAndScatter",
+                                         "XlaSelfAdjointEig",
+                                         "XlaSend",
+                                         "XlaSharding",
+                                         "XlaSort",
+                                         "XlaSpmdFullToShardShape",
+                                         "XlaSpmdShardToFullShape",
+                                         "XlaSvd",
+                                         "XlaWhile"};
+
+static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
+  return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
+         HasBoolAttr(node, kXlaMustCompileAttr) ||
+         HasBoolAttr(node, kXlaCompileAttr) ||
+         HasBoolAttr(node, kXlaScopeAttr) ||
+         HasBoolAttr(node, kXlaInternalScopeAttr) ||
+         ops_triggering_xla_compilation->count(node.op());
+}
+
+bool CanTriggerXlaCompilation(const GraphDef& graph) {
+  for (const FunctionDef& function : graph.library().function()) {
+    for (const NodeDef& node : function.node_def()) {
+      if (NodeCanTriggerXlaCompilation(node)) {
+        return true;
+      }
+    }
+  }
+
+  for (const NodeDef& node : graph.node()) {
+    if (NodeCanTriggerXlaCompilation(node)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h
index 3b20784..384367c 100644
--- a/tensorflow/compiler/jit/compilability_check_util.h
+++ b/tensorflow/compiler/jit/compilability_check_util.h
@@ -126,9 +126,10 @@
     bool allow_inaccurate_ops = false;
   };
 
-  RecursiveCompilabilityChecker(const OperationFilter* op_filter,
-                                const DeviceType* jit_device_type)
-      : op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
+  RecursiveCompilabilityChecker(OperationFilter op_filter,
+                                DeviceType jit_device_type)
+      : op_filter_(std::move(op_filter)),
+        jit_device_type_(std::move(jit_device_type)) {}
 
   using UncompilableNodesMap =
       std::map<std::string,
@@ -259,8 +260,8 @@
   // Make sure we don't recurse infinitely on recursive functions.
   const size_t kMaxRecursionDepth = 10;
 
-  const OperationFilter& op_filter_;
-  const DeviceType& jit_device_type_;
+  const OperationFilter op_filter_;
+  const DeviceType jit_device_type_;
 };
 
 RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
@@ -282,6 +283,9 @@
 // set.
 bool CanCreateXlaKernel(const NodeDef& node_def);
 
+// Checks whether the given graph can trigger XLA compilation.
+bool CanTriggerXlaCompilation(const GraphDef& graph);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
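
A sketch of how a caller might use the new check, for example a server that
keeps JIT compilation opt-in (MaybeRejectGraph and xla_opt_in are hypothetical
names, not part of this change):

#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/core/platform/errors.h"

// Hypothetical gate: reject graphs that could trigger JIT compilation
// unless the caller explicitly opted in.
tensorflow::Status MaybeRejectGraph(const tensorflow::GraphDef& graph_def,
                                    bool xla_opt_in) {
  if (!xla_opt_in && tensorflow::CanTriggerXlaCompilation(graph_def)) {
    return tensorflow::errors::InvalidArgument(
        "Graph may trigger XLA compilation, which this server disallows");
  }
  return tensorflow::Status::OK();
}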
diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc
index 3ea38e6..3851c66 100644
--- a/tensorflow/compiler/jit/compilability_check_util_test.cc
+++ b/tensorflow/compiler/jit/compilability_check_util_test.cc
@@ -18,6 +18,7 @@
 #include "absl/memory/memory.h"
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -75,8 +76,8 @@
     op_filter_.allow_inaccurate_ops = false;
     op_filter_.allow_slow_ops = false;
 
-    checker_ = absl::make_unique<RecursiveCompilabilityChecker>(&op_filter_,
-                                                                &device_type_);
+    checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
+                                                                device_type_);
   }
 
   FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
@@ -354,5 +355,110 @@
                                 "unsupported op"));
 }
 
+TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
+  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+  Scope root = Scope::NewRootScope().ExitOnError();
+  FunctionDefLibrary library;
+
+  FunctionDef identity_func = FunctionDefHelper::Create(
+      "IdentityFunc",
+      /*in_def=*/{"x:float"},
+      /*out_def=*/{"res:float"},
+      /*attr_def=*/{},
+      /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
+      /*ret_def=*/{{"res", "t0:output"}});
+
+  *library.add_function() = identity_func;
+
+  Output in = ops::Placeholder(root, DT_FLOAT);
+  NameAttrList b_name_attr;
+  b_name_attr.set_name("IdentityFunc");
+  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
+                            b_name_attr);
+
+  GraphDef graph_def;
+  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
+  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+  EXPECT_FALSE(CanTriggerXlaCompilation(graph_def));
+}
+
+TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) {
+  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+  Scope root = Scope::NewRootScope().ExitOnError();
+  FunctionDefLibrary library;
+
+  FunctionDef sort_func = FunctionDefHelper::Create(
+      "SortFunc",
+      /*in_def=*/{"x:float"},
+      /*out_def=*/{"res:float"},
+      /*attr_def=*/{},
+      /*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}},
+      /*ret_def=*/{{"res", "t0:output"}});
+
+  *library.add_function() = sort_func;
+
+  Output in = ops::Placeholder(root, DT_FLOAT);
+  NameAttrList b_name_attr;
+  b_name_attr.set_name("SortFunc");
+  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
+                            b_name_attr);
+
+  GraphDef graph_def;
+  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
+  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+  EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
+}
+
+TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) {
+  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+  Scope root = Scope::NewRootScope().ExitOnError();
+  FunctionDefLibrary library;
+
+  AttrValue true_attribute;
+  true_attribute.set_b(true);
+
+  FunctionDef identity_func = FunctionDefHelper::Create(
+      "IdentityFunc",
+      /*in_def=*/{"x:float"},
+      /*out_def=*/{"res:float"},
+      /*attr_def=*/{},
+      /*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
+      /*ret_def=*/{{"res", "t0:output"}});
+
+  (*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;
+
+  FunctionDef call_identity = FunctionDefHelper::Create(
+      "CallIdentity",
+      /*in_def=*/{"x:float"},
+      /*out_def=*/{"z:float"}, /*attr_def=*/{},
+      /*node_def=*/
+      {{{"func_call"},
+        "PartitionedCall",
+        {"x"},
+        {{"Tin", DataTypeSlice({DT_FLOAT})},
+         {"Tout", DataTypeSlice({DT_FLOAT})},
+         {"f",
+          FunctionDefHelper::FunctionRef("IdentityRef", {{"T", DT_FLOAT}})},
+         {kXlaMustCompileAttr, true}}}},
+      /*ret_def=*/{{"z", "func_call:output:0"}});
+
+  *library.add_function() = identity_func;
+  *library.add_function() = call_identity;
+
+  Output in = ops::Placeholder(root, DT_FLOAT);
+  NameAttrList b_name_attr;
+  b_name_attr.set_name("CallIdentity");
+  ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
+                            b_name_attr);
+
+  GraphDef graph_def;
+  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
+  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+  EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc
index 4bea71e..84e1e36 100644
--- a/tensorflow/compiler/jit/defs.cc
+++ b/tensorflow/compiler/jit/defs.cc
@@ -28,4 +28,6 @@
 // only when auto_jit is ON.
 const char* const kXlaInternalScopeAttr = "_XlaInternalScope";
 
+const char* const kXlaClusterIdAttr = "_xla_compile_id";
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h
index 9eb4c2c..fa983db 100644
--- a/tensorflow/compiler/jit/defs.h
+++ b/tensorflow/compiler/jit/defs.h
@@ -35,6 +35,9 @@
 extern const char* const kXlaScopeAttr;    // "_XlaScope"
 extern const char* const kXlaInternalScopeAttr;  // "_XlaInternalScope"
 
+// The id of the compiled cluster.
+extern const char* const kXlaClusterIdAttr;  // "_xla_compile_id"
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_DEFS_H_
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index ed25baa..4a5c79c 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -20,6 +20,7 @@
 #include "absl/memory/memory.h"
 #include "absl/strings/ascii.h"
 #include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -34,9 +35,6 @@
 
 namespace tensorflow {
 
-const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
-    "_xla_compile_id";
-
 namespace {
 
 const char* const kXlaClusterOutput = "XlaClusterOutput";
@@ -45,10 +43,7 @@
   for (Node* n : graph->nodes()) {
     string name;
     // Only consider nodes being compiled.
-    if (!GetNodeAttr(n->attrs(),
-                     EncapsulateXlaComputationsPass::kXlaClusterAttr, &name)
-             .ok())
-      continue;
+    if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue;
     // Early return for any node with a device that is not a CPU or GPU.
     DeviceNameUtils::ParsedName parsed;
     if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) {
@@ -180,8 +175,7 @@
     retvals[i]->AddAttr("index", i);
   }
 
-  AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
-              call_def);
+  AddNodeAttr(kXlaClusterIdAttr, call_def->name(), call_def);
   AddNodeAttr("_variable_start_index", variable_start_index, call_def);
 
   // Uniquify the function name.
@@ -216,8 +210,8 @@
   // O(n) pass over the edges.
   for (const Edge* e : (*graph)->edges()) {
     if (!e->IsControlEdge() &&
-        e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
-        e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
+        e->src()->attrs().Find(kXlaClusterIdAttr) != nullptr &&
+        e->dst()->attrs().Find(kXlaClusterIdAttr) == nullptr &&
         e->dst()->type_string() != kXlaClusterOutput) {
       return errors::InvalidArgument(
           "Undeclared output of XLA computation. Some common causes of this "
@@ -232,9 +226,9 @@
 
   auto output = absl::make_unique<Graph>((*graph)->op_registry());
   TF_RETURN_WITH_CONTEXT_IF_ERROR(
-      EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
-                                      /*reuse_existing_functions=*/true,
-                                      &output, flib_def),
+      EncapsulateSubgraphsInFunctions(
+          kXlaClusterIdAttr, **graph, RewriteSubgraph,
+          /*reuse_existing_functions=*/true, &output, flib_def),
       "EncapsulateXlaComputationsPass failed");
   graph->swap(output);
   return Status::OK();
@@ -246,7 +240,7 @@
   // while iterating.
   std::vector<Node*> launch_nodes;
   for (Node* n : graph->nodes()) {
-    const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr);
+    const string& name = GetNodeAttrString(n->attrs(), kXlaClusterIdAttr);
     if (!name.empty()) {
       launch_nodes.push_back(n);
     }
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
index 3057e4c..9931b23 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
@@ -34,8 +34,6 @@
 // XlaLaunch operators.
 class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
  public:
-  static const char* const kXlaClusterAttr;  // _xla_compile_id
-
   Status Run(const GraphOptimizationPassOptions& options) override;
 
   // The following methods are public only for unit tests.
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
index cc17703..61c9a3f 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
 #include "tensorflow/compiler/tf2xla/test_util.h"
@@ -46,19 +47,18 @@
   auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
 
   NodeDef def;
-  TF_CHECK_OK(
-      NodeDefBuilder("launch0", function, &flib_def)
-          .Input(a.node()->name(), 0, DT_INT32)
-          .Input(b.node()->name(), 0, DT_FLOAT)
-          .Input(c.node()->name(), 0, DT_INT32)
-          .Input(d.node()->name(), 0, DT_FLOAT)
-          .Input(u.node()->name(), 0, DT_RESOURCE)
-          .Input(v.node()->name(), 0, DT_RESOURCE)
-          .Input(w.node()->name(), 0, DT_RESOURCE)
-          .Device("/gpu:0")
-          .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
-          .Attr("_variable_start_index", 4)
-          .Finalize(&def));
+  TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def)
+                  .Input(a.node()->name(), 0, DT_INT32)
+                  .Input(b.node()->name(), 0, DT_FLOAT)
+                  .Input(c.node()->name(), 0, DT_INT32)
+                  .Input(d.node()->name(), 0, DT_FLOAT)
+                  .Input(u.node()->name(), 0, DT_RESOURCE)
+                  .Input(v.node()->name(), 0, DT_RESOURCE)
+                  .Input(w.node()->name(), 0, DT_RESOURCE)
+                  .Device("/gpu:0")
+                  .Attr(kXlaClusterIdAttr, "launch0")
+                  .Attr("_variable_start_index", 4)
+                  .Finalize(&def));
 
   Status status;
   Node* launch = scope.graph()->AddNode(def, &status);
@@ -107,7 +107,7 @@
   auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
 
   auto add_attrs = [](Node* node) {
-    node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+    node->AddAttr(kXlaClusterIdAttr, "launch0");
     node->set_requested_device("/gpu:0");
   };
 
@@ -155,8 +155,7 @@
                                     : ops::Add(scope.WithOpName("E"), a1, a0);
 
       auto add_attrs = [](Node* node) {
-        node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
-                      "launch0");
+        node->AddAttr(kXlaClusterIdAttr, "launch0");
       };
       add_attrs(e.node());
 
@@ -216,7 +215,7 @@
     auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
 
     auto add_attrs = [](Node* node) {
-      node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+      node->AddAttr(kXlaClusterIdAttr, "launch0");
       node->set_requested_device("/gpu:0");
     };
 
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index a4a750b..ee7daf0 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -268,4 +268,10 @@
   AppendMarkForCompilationPassFlagsInternal(flag_list);
 }
 
+static std::atomic<bool> xla_compilation_disabled(false);
+
+void DisableXlaCompilation() { xla_compilation_disabled = true; }
+
+bool FailOnXlaCompilation() { return xla_compilation_disabled; }
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index 6c54fc8..5612b3b 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -162,6 +162,13 @@
 void AppendMarkForCompilationPassFlags(
     std::vector<tensorflow::Flag>* flag_list);
 
+// Disables XLA compilation, forcing it to return an error instead. Can be
+// used by a server to ensure that JIT compilation is opt-in.
+void DisableXlaCompilation();
+
+// Returns `false` unless `DisableXlaCompilation` was called.
+bool FailOnXlaCompilation();
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_
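
A sketch of the intended use: a process that must never JIT-compile calls
DisableXlaCompilation once at startup, after which any compilation attempt
fails with the Internal error exercised by the new cache test below (the
main function here is hypothetical):

#include "tensorflow/compiler/jit/flags.h"

int main(int argc, char** argv) {
  // Make JIT compilation opt-in at the process level; from here on,
  // XlaCompilationCache::Compile returns an "XLA compilation disabled" error.
  tensorflow::DisableXlaCompilation();
  // ... build and run the server ...
  return 0;
}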
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index de46292..79f1e47 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -158,7 +158,7 @@
       constants_(constants),
       resources_(resources),
       function_(function),
-      platform_info_(XlaPlatformInfoFromContext(ctx)),
+      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
       has_ref_vars_(has_ref_vars) {}
 
 static Status CompileToLocalExecutable(
@@ -180,7 +180,7 @@
   TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
       rm->default_container(), "xla_cache", &cache,
       [&](XlaCompilationCache** cache) {
-        return BuildXlaCompilationCache(ctx, platform_info, cache);
+        return BuildXlaCompilationCache(ctx->device(), platform_info, cache);
       }));
   // Hold the reference to the JIT during evaluation. (We could probably
   // free it sooner because the ResourceMgr will retain a reference, but
@@ -191,7 +191,9 @@
 
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   XlaCompiler::Options options = GenerateCompilerOptions(
-      *cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
+      *cache, *ctx->function_library(), ctx->device(),
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
+      platform_info, has_ref_vars, &tf_allocator_adapter);
 
   std::map<int, Tensor> constant_args;
   for (int i : constants) {
@@ -248,8 +250,10 @@
   VLOG(1) << "Executing XLA Computation...";
 
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
-  se::DeviceMemoryAllocator* allocator =
-      GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
+  se::DeviceMemoryAllocator* allocator = GetAllocator(
+      &tf_allocator_adapter, ctx->device(),
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
+      platform_info_);
   int device_ordinal = stream ? stream->parent()->device_ordinal()
                               : client->default_device_ordinal();
   XlaComputationLaunchContext launch_context(
@@ -373,7 +377,7 @@
       constants_(ConstantsVector(ctx)),
       resources_(ResourcesVector(ctx)),
       function_(FunctionAttr(ctx)),
-      platform_info_(XlaPlatformInfoFromContext(ctx)),
+      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
       must_compile_(MustCompileAttr(ctx)),
       has_ref_vars_(HasRefVars(ctx)) {}
 
@@ -461,7 +465,7 @@
 }
 
 XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
-    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
+    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
 
 void XlaRunOp::Compute(OpKernelContext* ctx) {
   VLOG(3) << "XlaRunOp " << def().name();
@@ -472,8 +476,10 @@
       XlaExecutableClosureStore::Global()->Consume(key);
 
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
-  se::DeviceMemoryAllocator* allocator =
-      GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
+  se::DeviceMemoryAllocator* allocator = GetAllocator(
+      &tf_allocator_adapter, ctx->device(),
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
+      platform_info_);
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   int device_ordinal = stream ? stream->parent()->device_ordinal()
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 19eb61b..03ac7b0 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1196,12 +1196,9 @@
       continue;
     }
 
-    DeviceType jit_device_type(registration->compilation_device_name);
-
-    RecursiveCompilabilityChecker::OperationFilter op_filter =
-        CreateOperationFilter(*registration);
-
-    if (!RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
+    if (!RecursiveCompilabilityChecker{
+            CreateOperationFilter(*registration),
+            DeviceType{registration->compilation_device_name}}
              .IsCompilableNode(*node, lib_runtime)) {
       continue;
     }
@@ -1718,7 +1715,6 @@
   const XlaOpRegistry::DeviceRegistration* registration;
   CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
                                             &registration));
-  DeviceType jit_device_type(registration->compilation_device_name);
 
   // We can always *compile* resource operations, stateful RNGs and dummy ops,
   // even if we are sometimes unable to auto-cluster them.
@@ -1733,7 +1729,8 @@
   op_filter.allow_slow_ops = true;
   op_filter.allow_inaccurate_ops = true;
 
-  RecursiveCompilabilityChecker checker{&op_filter, &jit_device_type};
+  RecursiveCompilabilityChecker checker{
+      op_filter, DeviceType{registration->compilation_device_name}};
   if (!uncompilable_node_info) {
     // We do not need uncompilable node info. Just return the result.
     return checker.IsCompilableCall(ndef, flr);
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 971a538..b5bb2fa 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -20,6 +20,7 @@
 #include "absl/base/call_once.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/xla_activity.pb.h"
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
@@ -323,6 +324,10 @@
     absl::optional<int64> compile_threshold,
     const XlaCompiler::CompilationResult** out_compilation_result,
     xla::LocalExecutable** out_executable) {
+  if (FailOnXlaCompilation()) {
+    return errors::Internal("XLA compilation disabled");
+  }
+
   DCHECK_NE(out_executable, nullptr);
   VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
 
diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc
index 7227615..5578925 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache_test.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc
@@ -15,7 +15,9 @@
 
 #include "tensorflow/compiler/jit/xla_compilation_cache.h"
 
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
@@ -52,6 +54,30 @@
   }
 }
 
+TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) {
+  NameAttrList fn;
+  fn.set_name("afunction");
+
+  DisableXlaCompilation();
+
+  xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
+  DeviceType device_type = DeviceType(DEVICE_CPU_XLA_JIT);
+
+  const XlaCompiler::CompilationResult* compilation_result;
+  xla::LocalExecutable* executable;
+
+  auto cache = new XlaCompilationCache(client, device_type);
+  core::ScopedUnref cache_ref(cache);
+
+  Status status = cache->Compile(XlaCompiler::Options{}, fn, {},
+                                 XlaCompiler::CompileOptions{},
+                                 XlaCompilationCache::CompileMode::kStrict,
+                                 &compilation_result, &executable);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(status.error_message(), "XLA compilation disabled"));
+}
+
 static void BM_BuildSignature(int iters, int n_args) {
   NameAttrList fn;
   fn.set_name("afunction");
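This test exercises both halves of the flag plumbing added in this patch: DisableXlaCompilation() from the newly included flags.h sets the flag, and XlaCompilationCache::Compile checks FailOnXlaCompilation() and fails with the "XLA compilation disabled" Internal error introduced in xla_compilation_cache.cc above.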
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index da251c2..ba20b53 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -49,8 +49,10 @@
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
-  se::DeviceMemoryAllocator* allocator =
-      GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
+  se::DeviceMemoryAllocator* allocator = GetAllocator(
+      &tf_allocator_adapter, ctx->device(),
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
+      platform_info_);
   XlaComputationLaunchContext launch_context(
       client, allocator, client->default_device_ordinal(),
       /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
@@ -157,13 +159,16 @@
   TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
       rm->default_container(), "xla_cache", cache,
       [&](XlaCompilationCache** write_into_cache) {
-        return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache);
+        return BuildXlaCompilationCache(ctx->device(), platform_info_,
+                                        write_into_cache);
       }));
 
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
-  XlaCompiler::Options options =
-      GenerateCompilerOptions(**cache, ctx, platform_info_,
-                              /*has_ref_vars=*/true, &tf_allocator_adapter);
+  XlaCompiler::Options options = GenerateCompilerOptions(
+      **cache, *ctx->function_library(), ctx->device(),
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
+      platform_info_,
+      /*has_ref_vars=*/true, &tf_allocator_adapter);
 
   XlaCompiler::CompileOptions compile_options;
   compile_options.is_entry_computation = true;
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h
index 095d342..bb8ab88 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h
@@ -37,7 +37,8 @@
 class XlaCompileOnDemandOp : public OpKernel {
  public:
   explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
+      : OpKernel(ctx),
+        platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
   void Compute(OpKernelContext* ctx) override;
 
  private:
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 446cd89..dd1ddb6 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -51,7 +51,7 @@
     std::vector<std::unique_ptr<Device>>* devices) {
   XlaDeviceFlags* flags = GetXlaDeviceFlags();
   if (!flags->tf_xla_enable_xla_devices) {
-    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
+    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
     return Status::OK();
   }
   bool compile_on_demand = flags->tf_xla_compile_on_demand;
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index f7e7ee9..6d6086c 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -94,6 +94,11 @@
   static Status GetMetadata(OpKernelConstruction* ctx,
                             const Metadata** metadata);
 
+  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
+  // `device`.
+  static Status GetMetadataFromDevice(DeviceBase* device,
+                                      const XlaDevice::Metadata** metadata);
+
   struct Options {
     // The StreamExecutor platform. Not owned. Must be non-null.
     se::Platform* platform = nullptr;
@@ -196,8 +201,6 @@
   xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
   GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  static Status GetMetadataFromDevice(DeviceBase* device,
-                                      const XlaDevice::Metadata** metadata);
 
   Status MakeTensorFromProto(XlaDeviceContext* device_context,
                              const TensorProto& tensor_proto,
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 16f496d..99ba565 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -66,7 +66,7 @@
 Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
   XlaDeviceFlags* flags = GetXlaDeviceFlags();
   if (!flags->tf_xla_enable_xla_devices) {
-    LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
+    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
     return Status::OK();
   }
 
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 19e2b5a..ed6e399 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -44,12 +44,6 @@
 using xla::ScopedShapedBuffer;
 using xla::ShapedBuffer;
 
-const char kPossibleNonVariableResourceHintMessage[] =
-    "If the error is similar to `Trying to access resource using the wrong "
-    "type`, this is likely because XLA only accepts Resource Variables as "
-    "inputs by snapshotting their values. Other TensorFlow resource types like "
-    "TensorList/TensorArray/Stack are not supported. Try removing non-variable "
-    "resource inputs to XLA.";
 }  // anonymous namespace
 
 VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc
index a5e12b3..b38bf92 100644
--- a/tensorflow/compiler/jit/xla_platform_info.cc
+++ b/tensorflow/compiler/jit/xla_platform_info.cc
@@ -19,7 +19,7 @@
 
 namespace tensorflow {
 
-Status BuildXlaCompilationCache(OpKernelContext* ctx,
+Status BuildXlaCompilationCache(DeviceBase* device,
                                 const XlaPlatformInfo& platform_info,
                                 XlaCompilationCache** cache) {
   if (platform_info.xla_device_metadata()) {
@@ -59,7 +59,7 @@
   xla::LocalClientOptions client_options;
   client_options.set_platform(platform.ValueOrDie());
   client_options.set_intra_op_parallelism_threads(
-      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+      device->tensorflow_cpu_worker_threads()->num_threads);
   auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
   if (!client.ok()) {
     return client.status();
@@ -75,21 +75,21 @@
   return Status::OK();
 }
 
-XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
-  DeviceType device_type = ctx->device_type();
+XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
+  auto device = static_cast<Device*>(device_base);
   se::Platform::Id platform_id = nullptr;
   const XlaDevice::Metadata* xla_device_metadata = nullptr;
   se::DeviceMemoryAllocator* custom_allocator = nullptr;
 
-  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+  if (device->device_type() == DEVICE_CPU) {
     platform_id = se::host::kHostPlatformId;
-  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
-    platform_id = ctx->device()
-                      ->tensorflow_gpu_device_info()
+  } else if (device->device_type() == DEVICE_GPU) {
+    platform_id = device->tensorflow_gpu_device_info()
                       ->stream->parent()
                       ->platform()
                       ->id();
-  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+  } else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
+                 .ok()) {
     // If we are on an XlaDevice, use the underlying XLA platform's allocator
     // directly. We could use the StreamExecutor's allocator which may
     // theoretically be more correct, but XLA returns a nice OOM message in a
@@ -104,47 +104,46 @@
         xla_device_metadata->client()->backend().memory_allocator();
   }
 
-  return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
-                         custom_allocator);
+  return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
+                         xla_device_metadata, custom_allocator);
 }
 
 se::DeviceMemoryAllocator* GetAllocator(
     absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
-    OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
+    DeviceBase* device, se::Stream* stream,
+    const XlaPlatformInfo& platform_info) {
   if (platform_info.custom_allocator()) {
     return platform_info.custom_allocator();
   }
-  if (!ctx->op_device_context()) {
+  if (!stream) {
     // Stream is not set for the host platform.
     se::Platform* platform =
         se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
             .ValueOrDie();
-    tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
+    tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
     return &tf_allocator_adapter->value();
   }
-  tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
-                                ctx->op_device_context()->stream());
+  tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
   return &tf_allocator_adapter->value();
 }
 
 XlaCompiler::Options GenerateCompilerOptions(
-    const XlaCompilationCache& cache, OpKernelContext* ctx,
-    const XlaPlatformInfo& platform_info, bool has_ref_vars,
+    const XlaCompilationCache& cache,
+    const FunctionLibraryRuntime& function_library, DeviceBase* device,
+    se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
     absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
-  CHECK(ctx->function_library());
   XlaCompiler::Options options;
   options.client = static_cast<xla::LocalClient*>(cache.client());
-  if (ctx->op_device_context() != nullptr) {
-    options.device_ordinal =
-        ctx->op_device_context()->stream()->parent()->device_ordinal();
+  if (stream != nullptr) {
+    options.device_ordinal = stream->parent()->device_ordinal();
   }
   options.device_type = cache.device_type();
-  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
-  options.graph_def_version = ctx->function_library()->graph_def_version();
+  options.flib_def = function_library.GetFunctionLibraryDefinition();
+  options.graph_def_version = function_library.graph_def_version();
   options.allow_cpu_custom_calls =
       (platform_info.platform_id() == se::host::kHostPlatformId);
   options.device_allocator =
-      GetAllocator(tf_allocator_adapter, ctx, platform_info);
+      GetAllocator(tf_allocator_adapter, device, stream, platform_info);
   if (platform_info.xla_device_metadata()) {
     options.shape_representation_fn =
         platform_info.xla_device_metadata()->shape_representation_fn();
diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h
index d58b32a..bfb438c 100644
--- a/tensorflow/compiler/jit/xla_platform_info.h
+++ b/tensorflow/compiler/jit/xla_platform_info.h
@@ -80,27 +80,31 @@
 };
 
 // Returns the created XLA compilation cache.
-Status BuildXlaCompilationCache(OpKernelContext* ctx,
+Status BuildXlaCompilationCache(DeviceBase* device,
                                 const XlaPlatformInfo& platform_info,
                                 XlaCompilationCache** cache);
 
 // Returns information about the platform from the given device.
-XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
+XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
 
 // Returns the allocator from platform info if non-null, or populates and
 // returns a pointer to the allocator adapter with an allocator from the device.
 //
 // This is necessary because for XLA devices the underlying TF allocator returns
 // dummy tensors.
+//
+// The `stream` parameter may be null when running on the host.
 se::DeviceMemoryAllocator* GetAllocator(
     absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
-    OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
+    DeviceBase* device, se::Stream* stream,
+    const XlaPlatformInfo& platform_info);
 
 // Returns created options for the XLA compiler, and writes the used allocator
 // into `tf_allocator_adapter`.
 XlaCompiler::Options GenerateCompilerOptions(
-    const XlaCompilationCache& cache, OpKernelContext* ctx,
-    const XlaPlatformInfo& platform_info, bool has_ref_vars,
+    const XlaCompilationCache& cache,
+    const FunctionLibraryRuntime& function_library, DeviceBase* device,
+    se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
     absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
 
 }  // namespace tensorflow
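
Taken together, the xla_platform_info changes replace the OpKernelContext parameter with an explicit device and (possibly null) stream, so the helpers can also be called from code paths that have no kernel context. A minimal sketch of the resulting call pattern inside a kernel's Compute(), mirroring the call sites updated above (illustration only; error handling elided):

    // Derive the stream once; it is null on the host platform, and
    // GetAllocator falls back to a platform allocator in that case.
    se::Stream* stream =
        ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

    absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
    se::DeviceMemoryAllocator* allocator = GetAllocator(
        &tf_allocator_adapter, ctx->device(), stream, platform_info_);

    XlaCompiler::Options options = GenerateCompilerOptions(
        **cache, *ctx->function_library(), ctx->device(), stream,
        platform_info_, /*has_ref_vars=*/true, &tf_allocator_adapter);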
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 01c1877..5b79f78 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -25,10 +25,39 @@
 )
 
 cc_library(
+    name = "string_container_utils",
+    hdrs = ["utils/string_container_utils.h"],
+    deps = [
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "array_container_utils",
+    hdrs = ["utils/array_container_utils.h"],
+    deps = [
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "name_utils",
+    srcs = ["utils/name_utils.cc"],
+    hdrs = ["utils/name_utils.h"],
+    deps = [
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+    ],
+)
+
+cc_library(
     name = "op_or_arg_name_mapper",
     srcs = ["op_or_arg_name_mapper.cc"],
     hdrs = ["op_or_arg_name_mapper.h"],
     deps = [
+        ":name_utils",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
@@ -40,14 +69,14 @@
     srcs = ["tf_mlir_opt_main.cc"],
     deps = [
         ":init_mlir",
+        "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
         "//tensorflow/core:lib",
-        "//tensorflow/core/platform:logging",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AllPassesAndDialects",
-        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:MlirOptLib",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Shape",
     ],
 )
 
@@ -64,7 +93,6 @@
         # xla-legalize-tf-with-tf2xla pass.
         "//tensorflow/compiler/jit",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
@@ -127,11 +155,8 @@
     deps = [
         ":passes",
         ":tf_mlir_opt_main",
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration",
-        "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
-        "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration",
         "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing",
     ],
 )
@@ -141,7 +166,6 @@
     srcs = ["tf_mlir_translate_main.cc"],
     deps = [
         ":init_mlir",
-        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
@@ -168,3 +192,5 @@
     name = "litfiles",
     srcs = glob(["runlit*py"]),
 )
+
+exports_files(["run_lit.sh"])
diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl
index edbf366..1fa57ba 100644
--- a/tensorflow/compiler/mlir/glob_lit_test.bzl
+++ b/tensorflow/compiler/mlir/glob_lit_test.bzl
@@ -43,10 +43,10 @@
               and specifying a default driver will abort the tests.
       features: [str], list of extra features to enable.
     """
-    if driver != _default_driver:
-        fail("There is no present support for custom drivers. Please omit" +
-             " the driver parameter when running this test. If you require" +
-             " custom driver support, please file an issue to request it.")
+
+    # Remove the default_driver from the data: it does not exist as a file and is
+    # just a placeholder from the copybara rewrite.
+    data = [d for d in data if d != _default_driver]
 
     # Disable tests on windows for now, to enable testing rest of all xla and mlir.
     native.py_test(
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 126d446..4014b88 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -341,6 +341,7 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
@@ -349,6 +350,22 @@
 )
 
 cc_library(
+    name = "mhlo_control_flow_to_scf",
+    srcs = ["lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc"],
+    hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
+    deps = [
+        ":hlo",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+cc_library(
     name = "map_lmhlo_to_scalar_op",
     hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"],
     deps = [
@@ -459,6 +476,7 @@
         ":lhlo",
         ":map_lmhlo_to_scalar_op",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgOps",
@@ -477,9 +495,11 @@
     deps = [
         ":lhlo",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
@@ -797,6 +817,7 @@
         ":lhlo_legalize_to_affine",
         ":lhlo_legalize_to_gpu",
         ":lhlo_legalize_to_parallel_loops",
+        ":mhlo_control_flow_to_scf",
         ":mhlo_fusion",
         ":mhlo_to_mhlo_lowering_patterns",
         ":sink_constants_to_control_flow",
@@ -813,7 +834,8 @@
     ],
     deps = [
         ":all_passes",
-        ":hlo_dialect_registration",
+        ":hlo",
+        ":lhlo",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index d0abbe0..12b9f5a 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1079,6 +1079,8 @@
   // XLA semantics is available. This limitation is because of the current XLA
   // implementation.
   let results = (outs I32Tensor);
+
+  let hasFolder = 1;
 }
 
 def HLO_MapOp: HLO_Op<"map",
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index 3fa4658..750cce6 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -81,6 +81,8 @@
     ElementsAttr:$value,
     Arg<LHLO_Buffer, "", [MemWrite]>:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h
index 5773901..90ff6c9 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h
@@ -17,10 +17,13 @@
 #define MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_
 
 namespace mlir {
+class DialectRegistry;
 namespace mhlo {
 
 void registerAllDialects();
 
+// Add chlo, mhlo, lmhlo dialects to the provided registry.
+void registerAllMhloDialects(DialectRegistry &registry);
 }
 }  // namespace mlir
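
The forward declaration of DialectRegistry keeps this header light; the definition added in init.cc below inserts the chlo, lmhlo, and mhlo dialects into the given registry, the same three insertions the rewritten mlir-hlo-opt main performs explicitly at the end of this patch.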
 
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index fa3bde2..aa0f4c3 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -30,6 +30,11 @@
   let constructor = "createLegalizeControlFlowPass()";
 }
 
+def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
+  let summary = "Legalize from MHLO control flow to SCF control flow.";
+  let constructor = "createControlFlowToScfPass()";
+}
+
 def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
   let summary = "Legalizes gathers to a torch index select.";
   let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index efa116f..541d8e4 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -35,6 +35,9 @@
 /// Lowers HLO control flow ops to the Standard dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
 
+/// Lowers MHLO control flow ops to the SCF dialect.
+std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
+
 /// Lowers from HLO dialect to Standard dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index f5deb94..b0fa4ce 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -166,6 +166,20 @@
 }
 
 //===----------------------------------------------------------------------===//
+// GetDimensionSizeOp
+//===----------------------------------------------------------------------===//
+
+/// Fold get_dimension_size when the queried shape dimension is static.
+OpFoldResult GetDimensionSizeOp::fold(ArrayRef<Attribute> attrs) {
+  RankedTensorType type = operand().getType().cast<RankedTensorType>();
+  int32_t dim = dimension().getSExtValue();
+  if (type.isDynamic(dim)) return {};
+  // The result type is always a 0-d i32 tensor.
+  return DenseIntElementsAttr::get<int32_t>(
+      getResult().getType().cast<RankedTensorType>(), type.getDimSize(dim));
+}
+
+//===----------------------------------------------------------------------===//
 // IotaOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc
index 9fffeae..cf8bd25 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc
@@ -31,3 +31,11 @@
 
   // Dependent dialects
 }
+
+void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry &registry) {
+  // clang-format off
+  registry.insert<mlir::chlo::HloClientDialect,
+                  mlir::lmhlo::LmhloDialect,
+                  mlir::mhlo::MhloDialect>();
+  // clang-format on
+}
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
index f61a663..81407c8 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc
@@ -29,6 +29,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
@@ -57,6 +58,38 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ConstOp.
+//===----------------------------------------------------------------------===//
+
+/// An lmhlo.constant on a memref that is locally allocated and has no other
+/// users (other than deallocs) can be erased.
+// TODO: This can be generalized to an arbitrary op by making use of memory
+// effects (write memory effect).
+struct EraseConstOp : public OpRewritePattern<ConstOp> {
+  using OpRewritePattern<ConstOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConstOp op,
+                                PatternRewriter& rewriter) const override {
+    Value memref = op.output();
+    if (!memref.getDefiningOp<AllocOp>()) {
+      return failure();
+    }
+
+    // Check that all uses of the memref are either DeallocOps or this op.
+    for (Operation* user : memref.getUsers())
+      if (user != op && !isa<DeallocOp>(user)) return failure();
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
+                                          MLIRContext* context) {
+  results.insert<EraseConstOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
 // StaticMemRefCastOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
index bb9f98d..945fa0e 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt
@@ -93,6 +93,7 @@
 add_mlir_library(MhloToStandard
   legalize_control_flow.cc
   legalize_to_standard.cc
+  mhlo_control_flow_to_scf.cc
 
   DEPENDS
   MLIRhlo_opsIncGen
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
index 50cd6df..263b6cd 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc
@@ -29,6 +29,10 @@
 
 struct TestChloLegalizeToHloPass
     : public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
+  }
+
   void runOnFunction() override {
     ConversionTarget conversionTarget(getContext());
     OwningRewritePatternList conversionPatterns;
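
The getDependentDialects override added here recurs throughout the passes below: now that dialects are no longer registered globally, each pass must declare every dialect it may create operations from, so the context can load them before the pass pipeline starts. The general shape, shown with a hypothetical pass for illustration:

    struct MyLoweringPass : public PassWrapper<MyLoweringPass, FunctionPass> {
      // Declare the dialects whose ops this pass may create; without this,
      // emitting ops from an unloaded dialect fails at pass execution time.
      void getDependentDialects(DialectRegistry& registry) const override {
        registry.insert<scf::SCFDialect, StandardOpsDialect>();
      }
      void runOnFunction() override {
        // Conversion patterns would be applied here.
      }
    };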
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index a8c3ad1..a784c05 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -388,6 +388,10 @@
 
 struct HloLegalizeToLhlo
     : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<lmhlo::LmhloDialect>();
+  }
+
  public:
   HloLegalizeToLhlo() = default;
   HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index f47f2c2..ff780fd 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -15,6 +15,8 @@
 
 // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
 
+#include <numeric>
+
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@@ -598,6 +600,7 @@
     unsigned currSrcDim = 0, currDstDim = 0;
     SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
         dstShape.size());
+    bool isExpandingOrCollapsing = true;
     while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
       int64_t dstSize = dstShape[currDstDim];
       int64_t srcSize = srcShape[currSrcDim];
@@ -619,11 +622,47 @@
           }
         }
       } else {
-        return failure();
+        isExpandingOrCollapsing = false;
+        break;
       }
       currDstDim++;
     }
-    if (currSrcDim != srcShape.size()) return failure();
+    if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;
+
+    if (!isExpandingOrCollapsing) {
+      auto getIdentityExprs = [&rewriter](int n) {
+        SmallVector<AffineExpr, 4> exprs;
+        for (int i = 0; i < n; ++i)
+          exprs.push_back(rewriter.getAffineDimExpr(i));
+        return exprs;
+      };
+      Location loc = reshapeOp.getLoc();
+      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
+                                           std::multiplies<int64_t>());
+      auto elemType = operandType.getElementType();
+      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+          getIdentityExprs(dstShape.size())};
+      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+          getIdentityExprs(srcShape.size())};
+
+      if (isLHLO) {
+        auto collapsedType = MemRefType::get({totalElems}, elemType);
+        Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
+            loc, collapsedType, args[0], collapsingMap);
+        Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
+            loc, resultType, collapsedOp, expandingMap);
+        rewriter.replaceOpWithNewOp<linalg::CopyOp>(
+            reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+            /*outputPermutation =*/nullptr);
+      } else {
+        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
+        Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
+            loc, collapsedType, args[0], collapsingMap);
+        rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+            reshapeOp, resultType, collapsedOp, expandingMap);
+      }
+      return success();
+    }
 
     if (isLHLO) {
       Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
@@ -827,6 +866,10 @@
 // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
 struct LhloLegalizeToLinalgPass
     : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect>();
+  }
+
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
@@ -843,6 +886,10 @@
 
 struct HloLegalizeToLinalgPass
     : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<linalg::LinalgDialect>();
+  }
+
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
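
For reshapes that are neither pure expansions nor pure collapses, the fallback added above routes through a rank-1 intermediate: 1x49x16 -> 1x784x1x1, for instance, is lowered as a collapse of 1x49x16 to 784 followed by an expansion of 784 to 1x784x1x1 (plus a final copy in the LHLO case), as exercised by the reshape_3D_4D tests added later in this patch.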
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
index cc574e0..5000fce 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc
@@ -178,6 +178,10 @@
 namespace {
 struct LegalizeToStandardPass
     : public PassWrapper<LegalizeToStandardPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<StandardOpsDialect>();
+  }
+
   /// Perform the lowering to Standard dialect.
   void runOnFunction() override;
 };
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc
index 1467f01..6dc5b64 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc
@@ -19,8 +19,10 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -33,6 +35,10 @@
 
 class LhloFuseLinalgPass
     : public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
+  }
+
  public:
   LhloFuseLinalgPass() = default;
   LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
index 0789132..2771afc 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc
@@ -139,6 +139,9 @@
 
 struct LhloLegalizeToAffinePass
     : public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<AffineDialect>();
+  }
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto func = getFunction();
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
index cffb58b..fbade8f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc
@@ -20,8 +20,10 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
@@ -169,6 +171,11 @@
 
 struct LhloLegalizeToGpuPass
     : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
+                    scf::SCFDialect>();
+  }
+
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     ConversionTarget target(getContext());
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc
index 8493a1f..3d49027 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc
@@ -29,6 +29,10 @@
 class TestLhloToLLVMPass
     : public ::mlir::PassWrapper<TestLhloToLLVMPass,
                                  ::mlir::OperationPass<::mlir::ModuleOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
  public:
   void runOnOperation() override {
     ModuleOp m = getOperation();
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
index 19f47d0..d9a2d99 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc
@@ -691,6 +691,10 @@
 
 struct LhloLegalizeToParallelLoopsPass
     : public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<StandardOpsDialect, scf::SCFDialect>();
+  }
+
   void runOnFunction() override {
     auto func = getFunction();
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc
new file mode 100644
index 0000000..aba7b07
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc
@@ -0,0 +1,199 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/Support/Casting.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+#define DEBUG_TYPE "mhlo-control-flow-to-scf"
+
+namespace mlir {
+namespace mhlo {
+
+namespace {
+
+/// Convert MHLO While to SCF.
+void MatchAndRewrite(WhileOp whileOp);
+
+/// Pass that converts MHLO control flow to SCF.
+class ControlFlowToScfPass
+    : public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<scf::SCFDialect>();
+  }
+  void runOnFunction() override {
+    getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
+  }
+};
+
+// TODO(jpienaar): Look into reformulating as a pattern.
+void MatchAndRewrite(WhileOp whileOp) {
+  // Handle pattern:
+  //   x = start
+  //   step = ...
+  //   limit = ...
+  //   while (x < limit) { ... x += step; }
+
+  // Only handling multi-value while loops at the moment.
+  auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
+  if (!tupleOp) return;
+  auto bodyReturn = whileOp.body()
+                        .front()
+                        .getTerminator()
+                        ->getOperand(0)
+                        .getDefiningOp<mhlo::TupleOp>();
+  // Note: due to the shape restrictions on While, if the operand to While is a
+  // tuple, then so is the return type of the body. But the verifier isn't
+  // checking that at the moment, so just bail out here if this doesn't hold.
+  if (!bodyReturn) return;
+
+  Value result = whileOp.cond().front().getTerminator()->getOperand(0);
+  // TODO(jpienaar): Expand to handle more than simple case with LT compare and
+  // constant step.
+  auto cmp = result.getDefiningOp<mhlo::CompareOp>();
+  if (!cmp || cmp.comparison_direction() != "LT") return;
+
+  const int kConstant = -1;
+  auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
+    if (matchPattern(val, m_Constant())) return {val, kConstant};
+    // If it is defined by a get_tuple_element on the loop argument, then the
+    // value was fed in through the While operand tuple; capture the
+    // corresponding external value.
+    if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
+      if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
+      int index = gte.index().getSExtValue();
+      return {tupleOp.getOperand(index), index};
+    }
+    return {nullptr, 0};
+  };
+
+  using ValueIndex = std::pair<Value, int>;
+  ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
+  ValueIndex max = getValueAndIndex(cmp.rhs());
+  if (!loopIndVar.first || !max.first) return;
+  auto add =
+      bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
+  if (!add) return;
+  ValueIndex step = getValueAndIndex(add.rhs());
+  if (step.second != kConstant || !step.first) return;
+
+  // Only handle the case where the tuple isn't propagated as-is for now.
+  // TODO(jpienaar): Remove this when a tuple is also created inside the loop
+  // to propagate.
+  for (auto* use : whileOp.body().front().getArgument(0).getUsers())
+    if (!isa<GetTupleElementOp>(use)) return;
+
+  LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
+             llvm::dbgs() << "  loopIndVar = " << loopIndVar.second << " max = "
+                          << max.second << " step = " << step.second << "\n";
+             llvm::dbgs() << "  loopIndVar = " << loopIndVar.first << " max = "
+                          << max.first << " step = " << step.first << "\n";);
+  OpBuilder b(whileOp);
+  // Inputs to new for loop.
+  llvm::SmallVector<Value, 4> input;
+  input.reserve(tupleOp.getNumOperands());
+  for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
+    input.push_back(r);
+  for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
+    input.push_back(r);
+
+  auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
+  auto getAsIndex = [&](Value val) {
+    auto loc = whileOp.getLoc();
+    return b.create<ExtractElementOp>(
+        loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
+  };
+
+  // scf.for uses index type, so convert these values.
+  auto forloopIndVar = getAsIndex(loopIndVar.first);
+  auto forMax = getAsIndex(max.first);
+  auto forStep = getAsIndex(step.first);
+  auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
+                                          forMax, forStep, input);
+  // Transfer the body without the block arguments.
+  forOp.getLoopBody().front().getOperations().splice(
+      forOp.getLoopBody().front().getOperations().end(),
+      whileOp.body().front().getOperations());
+
+  b.setInsertionPointToStart(&forOp.getLoopBody().front());
+  auto loopIndVarElType =
+      loopIndVar.first.getType().cast<ShapedType>().getElementType();
+  Value indVar = b.create<SplatOp>(
+      whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
+      b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
+                            forOp.getInductionVar()));
+  // Update all block argument users to the SCF For args.
+  for (auto* use :
+       llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
+    // TODO(jpienaar): Expand here too when we allow using the tuple in the
+    // loop.
+    auto gte = cast<GetTupleElementOp>(use);
+    // If this is the loop induction variable, refer to the for loop's
+    // induction variable instead, as this operand is not updated.
+    if (gte.index() == loopIndVar.second) {
+      use->getResult(0).replaceAllUsesWith(indVar);
+      use->erase();
+      continue;
+    }
+    int index = gte.index().getSExtValue();
+    // If the index comes after the loop induction variable, decrement it,
+    // since the induction variable is not included in the for iter operands.
+    if (index > loopIndVar.second) --index;
+    use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
+    use->erase();
+  }
+
+  // Create new yield op without induction var update.
+  SmallVector<Value, 4> newYieldOps;
+  newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
+  for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
+    newYieldOps.push_back(r);
+  for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
+    newYieldOps.push_back(r);
+  // Delete return & tuple op.
+  forOp.getLoopBody().front().back().erase();
+  forOp.getLoopBody().front().back().erase();
+  b.setInsertionPointToEnd(&forOp.getLoopBody().front());
+  b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);
+
+  // Recombine output tuple with max value of induction variable.
+  llvm::SmallVector<Value, 4> loopOut;
+  loopOut.reserve(forOp.getNumResults() + 1);
+  for (auto r : forOp.getResults().take_front(loopIndVar.second))
+    loopOut.push_back(r);
+  loopOut.push_back(max.first);
+  for (auto r : forOp.getResults().drop_front(loopIndVar.second))
+    loopOut.push_back(r);
+  b.setInsertionPoint(whileOp);
+  auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
+  whileOp.replaceAllUsesWith(newRes.getOperation());
+  whileOp.erase();
+}
+
+}  // anonymous namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
+  return std::make_unique<ControlFlowToScfPass>();
+}
+
+}  // namespace mhlo
+}  // namespace mlir
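
The new pass is exposed as mhlo-control-flow-to-scf (per the tablegen entry above) and is driven in the legalize_to_scf.mlir test below through mlir-hlo-opt --mhlo-control-flow-to-scf. A minimal sketch of invoking it programmatically, assuming a module and registered context set up elsewhere:

    mlir::PassManager pm(module.getContext());
    // The pass runs on functions, so nest it under the module pass manager.
    pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createControlFlowToScfPass());
    if (mlir::failed(pm.run(module))) {
      // Handle the failure (e.g., report and abort).
    }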
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index 7c985ea..58a0f0e 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -153,6 +153,10 @@
 
 struct TransformUnrankedHloPass
     : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<shape::ShapeDialect>();
+  }
+
   void runOnFunction() override {
     // Setup conversion target.
     MLIRContext &ctx = getContext();
diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
index 15b1a15..0b6cc1c 100644
--- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir
@@ -597,3 +597,32 @@
   // CHECK: return [[ARG0]]
   return %3 : tuple<tensor<i32>>
 }
+
+// CHECK-LABEL: func @erase_dead_lhlo_constant
+func @erase_dead_lhlo_constant() {
+  %M = alloc() : memref<256x1024xf32>
+  // CHECK-NEXT: return
+  "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
+  dealloc %M : memref<256x1024xf32>
+  return
+}
+
+// A negative test for dead lhlo constant op erasure.
+// CHECK-LABEL: func @erase_dead_lhlo_constant_negative
+func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
+  // CHECK-NEXT: lmhlo.constant
+  "lmhlo.constant"(%M) {value = dense<0.0> : tensor<f32>} : (memref<4xf32>) -> ()
+  // CHECK-NEXT: alloc
+  // CHECK-NEXT: lmhlo.constant
+  %N = alloc() : memref<256x1024xf32>
+  "lmhlo.constant"(%N) {value = dense<0.0> : tensor<f32>} : (memref<256x1024xf32>) -> ()
+  return %N : memref<256x1024xf32>
+}
+
+// CHECK-LABEL: func @fold_get_dimension_size
+func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor<i32> {
+  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor<i32>
+  return %size : tensor<i32>
+  // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor<i32>
+  // CHECK-NEXT: return %[[C]]
+}
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
index 46725e0..aecf612 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
@@ -373,6 +373,18 @@
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @reshape_3D_4D
+func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
+  return %0 : tensor<1x784x1x1xf32>
+}
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+
+// -----
+
 // CHECK-LABEL: func @minf
 func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %0 = "mhlo.minimum"(%lhs, %rhs)
diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir
new file mode 100644
index 0000000..9c887a7
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s
+
+func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) {
+  %cst = constant dense<-1> : tensor<i32>
+  %cst_0 = constant dense<1> : tensor<i32>
+  %cst_1 = constant dense<0> : tensor<i32>
+  %cst_2 = constant dense<1000> : tensor<i32>
+  %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+  %1 = "mhlo.while"(%0) ( {
+  ^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):  // no predecessors
+    %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%4) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):  // no predecessors
+    %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %3 = mhlo.add %2, %cst_0 : tensor<i32>
+    %4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %6 = "mhlo.tuple"(%3, %4, %5) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+    "mhlo.return"(%6) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> ()
+  }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+  return %1 : tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+}
+
+// CHECK-LABEL:   func @lt_loop(
+// CHECK:  %[[VAL_9:.*]] = constant dense<-1> : tensor<i32>
+// CHECK:  %[[VAL_10:.*]] = constant dense<1> : tensor<i32>
+// CHECK:  %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
+// CHECK:  %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
+// CHECK:  %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
+// CHECK:  %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
+// CHECK:  %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
+// CHECK:  %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
+// CHECK:  %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
+// CHECK:  %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
+// CHECK:  scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
index 768d8da..f174b00 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
@@ -688,6 +688,20 @@
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @reshape_3D_4D
+func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
+  "lmhlo.reshape"(%arg0, %arg1)
+   : (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> ()
+  return
+}
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+// CHECK: linalg.copy
+
+// -----
+
 // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
 // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @reverse
diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
index c071e2c..d0c0e3c 100644
--- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
+++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
@@ -13,112 +13,25 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include "mlir-hlo/Dialect/mhlo/IR/register.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
-#include "mlir/IR/AsmState.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/MLIRContext.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllPasses.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/MlirOptMain.h"
 
-// NOLINTNEXTLINE
-static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
-                                                llvm::cl::desc("<input file>"),
-                                                llvm::cl::init("-"));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<std::string> outputFilename(
-    "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
-    llvm::cl::init("-"));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> splitInputFile(
-    "split-input-file",
-    llvm::cl::desc("Split the input file into pieces and process each "
-                   "chunk independently"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> verifyDiagnostics(
-    "verify-diagnostics",
-    llvm::cl::desc("Check that emitted diagnostics match "
-                   "expected-* lines on the corresponding line"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> verifyPasses(
-    "verify-each",
-    llvm::cl::desc("Run the verifier after each transformation pass"),
-    llvm::cl::init(true));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> allowUnregisteredDialects(
-    "allow-unregistered-dialect",
-    llvm::cl::desc("Allow operation with no registered dialects"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> showDialects(
-    "show-dialects", llvm::cl::desc("Print the list of registered dialects"),
-    llvm::cl::init(false));
-
 int main(int argc, char **argv) {
-  mlir::registerAllDialects();
   mlir::registerAllPasses();
-
-  mlir::mhlo::registerAllDialects();
   mlir::mhlo::registerAllMhloPasses();
   mlir::lmhlo::registerAllLmhloPasses();
 
-  llvm::InitLLVM y(argc, argv);
+  mlir::DialectRegistry registry;
+  mlir::registerAllDialects(registry);
+  registry.insert<mlir::mhlo::MhloDialect>();
+  registry.insert<mlir::chlo::HloClientDialect>();
+  registry.insert<mlir::lmhlo::LmhloDialect>();
 
-  // Register any pass manager command line options.
-  mlir::registerAsmPrinterCLOptions();
-  mlir::registerMLIRContextCLOptions();
-  mlir::registerPassManagerCLOptions();
-  mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
-
-  // Parse pass names in main to ensure static initialization completed.
-  llvm::cl::ParseCommandLineOptions(argc, argv,
-                                    "MLIR modular optimizer driver\n");
-
-  if (showDialects) {
-    mlir::MLIRContext context;
-    llvm::outs() << "Registered Dialects:\n";
-    for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
-      llvm::outs() << dialect->getNamespace() << "\n";
-    }
-    return 0;
-  }
-
-  // Set up the input file.
-  std::string errorMessage;
-  auto file = mlir::openInputFile(inputFilename, &errorMessage);
-  if (!file) {
-    llvm::errs() << errorMessage << "\n";
-    return 1;
-  }
-
-  auto output = mlir::openOutputFile(outputFilename, &errorMessage);
-  if (!output) {
-    llvm::errs() << errorMessage << "\n";
-    exit(1);
-  }
-
-  if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
-                         splitInputFile, verifyDiagnostics, verifyPasses,
-                         allowUnregisteredDialects))) {
-    return 1;
-  }
-  // Keep the output file if the invocation of MlirOptMain was successful.
-  output->keep();
-  return 0;
+  return failed(
+      mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));
 }
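The mlir-hlo-opt rewrite above delegates option parsing, -split-input-file,
-verify-diagnostics, and output handling to MlirOptMain. A minimal sketch of
the same pattern for a hypothetical out-of-tree opt tool, assuming the MLIR
APIs at this revision (DialectRegistry plus the argc/argv MlirOptMain
overload):

#include "mlir/IR/Dialect.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"

int main(int argc, char **argv) {
  // Passes still use static registration; dialects no longer do.
  mlir::registerAllPasses();

  // Dialects are collected in a registry and loaded lazily per context,
  // instead of being registered into a global registry at startup.
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);

  return failed(mlir::MlirOptMain(argc, argv, "My opt driver\n", registry));
}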
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 0a93b96..0231441 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -281,6 +281,28 @@
 )
 
 cc_library(
+    name = "nms_utils",
+    srcs = [
+        "utils/nms_utils.cc",
+    ],
+    hdrs = [
+        "utils/nms_utils.h",
+    ],
+    copts = ["-std=c++14"],
+    deps = [
+        ":tensorflow_lite",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
+        "//tensorflow/core:framework",
+        "@flatbuffers",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+cc_library(
     name = "tftext_utils",
     srcs = [
         "utils/tftext_utils.cc",
@@ -373,6 +395,7 @@
     deps = [
         ":constant_utils",
         ":lstm_utils",
+        ":nms_utils",
         ":stateful_ops_utils",
         ":tensorflow_lite",
         ":tftext_utils",
@@ -509,19 +532,6 @@
     ],
 )
 
-# Library with tensorflow Lite dialect static initialization.
-cc_library(
-    name = "tensorflow_lite_dialect_registration",
-    srcs = [
-        "ir/dialect_registration.cc",
-    ],
-    deps = [
-        ":tensorflow_lite",
-        "@llvm-project//mlir:IR",
-    ],
-    alwayslink = 1,
-)
-
 tf_native_cc_binary(
     name = "converter-gen",
     srcs = [
@@ -628,7 +638,6 @@
         ":flatbuffer_tflite_operator_lib",
         ":stateful_ops_utils",
         ":tensorflow_lite",
-        ":tensorflow_lite_dialect_registration",
         "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
@@ -672,7 +681,7 @@
         ":convert_type",
         ":flatbuffer_tflite_operator_lib",
         ":tensorflow_lite",
-        ":tensorflow_lite_dialect_registration",
+        "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:mangling_util",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/compiler/xla:statusor",
@@ -740,16 +749,13 @@
     ],
     deps = [
         ":flatbuffer_translate_lib",
+        ":tensorflow_lite",
+        "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
-        "@com_google_absl//absl/base",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:MlirTranslateMain",
         "@llvm-project//mlir:QuantOps",
-        "@llvm-project//mlir:SCFTransforms",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Translation",
@@ -762,7 +768,7 @@
     deps = [
         ":flatbuffer_translate_registeration",
         # TODO(b/155809683): Link only necessary dialects.
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
     ],
 )
 
@@ -814,7 +820,7 @@
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         # TODO(b/155809683): Link only necessary dialects.
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
@@ -838,19 +844,18 @@
     deps = [
         ":flatbuffer_translate_lib",
         ":flatbuffer_translate_registeration",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:Support",
-        # TODO(b/155809683): Link only necessary dialects.
-        "@llvm-project//mlir:AllPassesAndDialects",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Support",
-        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+        ":tensorflow_lite",
+        "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/delegates/flex:delegate",
         "//tensorflow/lite/kernels:builtin_ops",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:StandardOps",
     ],
 )
 
@@ -877,7 +882,7 @@
         "//tensorflow/compiler/mlir/tensorflow:translate_lib",
         "//tensorflow/core:core_cpu_base",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Transforms",
@@ -911,7 +916,7 @@
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
index c3a0800..34200fb 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -61,6 +61,7 @@
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
@@ -354,8 +355,13 @@
     if (emit_custom_ops) {
       enabled_op_types_.emplace(OpType::kCustomOp);
     }
-    tf_dialect_ = module.getContext()->getRegisteredDialect("tf");
-    tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl");
+    tf_dialect_ =
+        module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
+    tfl_dialect_ = module.getContext()
+                       ->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
+    // Right now the TF executor dialect is still needed to build NodeDef.
+    module.getContext()
+        ->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
   }
 
   Optional<std::string> TranslateInternal();
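This hunk tracks MLIR's move away from global dialect registration:
getRegisteredDialect only returned a dialect that something else had already
loaded, while getOrLoadDialect loads it on demand. A minimal sketch of the
difference, using the TF dialect as in the hunk above (the free function is
hypothetical):

#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

void LoadTfDialect() {
  mlir::MLIRContext context;
  // Old: getRegisteredDialect("tf") returned nullptr unless the TF dialect
  // was already loaded into this context.
  // New: getOrLoadDialect loads it on demand and never returns null.
  auto *tf = context.getOrLoadDialect<mlir::TF::TensorFlowDialect>();
  (void)tf;
}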
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index c46c4a7..2303837 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -65,6 +65,7 @@
 #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -479,7 +480,7 @@
 
     value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
   } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
-    auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
+    auto dialect = elem_type.getContext()->getLoadedDialect("tf");
     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
     std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
 
@@ -1072,6 +1073,10 @@
     const std::vector<std::string>& ordered_input_arrays,
     const std::vector<std::string>& ordered_output_arrays,
     bool experimental_prune_unreachable_nodes_unconditionally) {
+  context->loadDialect<
+      mlir::StandardOpsDialect, mlir::quant::QuantizationDialect,
+      mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>();
+
   auto model_ptr =
       FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
   if (nullptr == model_ptr) {
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
index 3a47d07..5accb41 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
@@ -249,7 +249,7 @@
       {static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
   attributes->emplace_back(builder.getNamedAttr(
       "custom_option",
-      OpaqueElementsAttr::get(builder.getContext()->getRegisteredDialect("tfl"),
+      OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"),
                               type, content)));
 
   return Status::OK();
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index 5b95b30..94f7e22 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -17,6 +17,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
@@ -33,6 +34,8 @@
 #include "mlir/Translation.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 
 using llvm::cl::opt;
@@ -175,5 +178,11 @@
     });
 
 static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate(
-    "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction);
+    "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction,
+    [](DialectRegistry& registry) {
+      registry.insert<quant::QuantizationDialect>();
+      registry.insert<TF::TensorFlowDialect>();
+      registry.insert<TFL::TensorFlowLiteDialect>();
+      registry.insert<StandardOpsDialect>();
+    });
 }  // namespace mlir
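The new lambda argument is how a translation declares the dialects its input
modules may contain, now that mlir-translate no longer sees globally
registered dialects. A stripped-down sketch of the same registration with a
hypothetical translation function:

#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/Translation.h"

// Hypothetical translation body; the real one wraps flatbuffer_export.
static mlir::LogicalResult MyTranslate(mlir::ModuleOp module,
                                       llvm::raw_ostream &output) {
  module.print(output);
  return mlir::success();
}

static mlir::TranslateFromMLIRRegistration MyRegistration(
    "my-translation", MyTranslate,
    [](mlir::DialectRegistry &registry) {
      // Every dialect the parsed input may use must be declared up front.
      registry.insert<mlir::StandardOpsDialect>();
    });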
diff --git a/tensorflow/compiler/mlir/lite/ir/dialect_registration.cc b/tensorflow/compiler/mlir/lite/ir/dialect_registration.cc
deleted file mode 100644
index fae2043..0000000
--- a/tensorflow/compiler/mlir/lite/ir/dialect_registration.cc
+++ /dev/null
@@ -1,19 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
-
-// Static initialization for TensorFlow Lite op registration.
-static mlir::DialectRegistration<mlir::TFL::TensorFlowLiteDialect> tfl_ops;
diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
index 0d42fbb..35a58a0 100644
--- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
+++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
@@ -30,12 +30,16 @@
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/SMLoc.h"
 #include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/delegates/flex/delegate.h"
@@ -98,6 +102,10 @@
 
   // Load the MLIR module.
   mlir::MLIRContext context;
+  context.getDialectRegistry()
+      .insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
+              mlir::StandardOpsDialect>();
+
   llvm::SourceMgr source_mgr;
   source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
   mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
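The registry insertions matter here because parseSourceFile can only
materialize ops whose dialects are available in the context. A condensed
sketch of that parse path (the helper function is hypothetical):

#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"

mlir::OwningModuleRef ParseModule(llvm::SourceMgr &source_mgr,
                                  mlir::MLIRContext &context) {
  // Ops from any dialect not named here are rejected at parse time as
  // belonging to an unregistered dialect.
  context.getDialectRegistry().insert<mlir::StandardOpsDialect>();
  return mlir::parseSourceFile(source_mgr, &context);
}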
diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
index 6299a70..7e7d467 100644
--- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
@@ -62,6 +62,10 @@
 
   void runOnFunction() override;
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<quant::QuantizationDialect>();
+  }
+
   // Parses the serialized quant stats protobuf and initialize the internal
   // data structure. This method must be called after the pass is created.
   bool ParseQuantStats(const std::string &stats_str);
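getDependentDialects is the pass-side piece of the same migration: a pass
must declare every dialect whose ops it may create, so the pass manager can
load them before multi-threaded execution starts (loading a dialect mid-run
would mutate the context under concurrent readers). A minimal sketch on a
hypothetical function pass:

#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Pass/Pass.h"

struct ExamplePass : public mlir::PassWrapper<ExamplePass, mlir::FunctionPass> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // Declared because runOnFunction may emit quant-dialect ops.
    registry.insert<mlir::quant::QuantizationDialect>();
  }
  void runOnFunction() override {
    // Rewrites that create quant-dialect ops would go here.
  }
};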
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
index 31c0e4c..38c7ad8 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
@@ -28,6 +28,7 @@
     deps = [
         "//tensorflow/compiler/mlir/lite:common",
         "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
         "//tensorflow/compiler/mlir/tensorflow:error_util",
@@ -74,6 +75,6 @@
         "//tensorflow/lite/schema:schema_fbs",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
     ],
 )
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
index a2e3c06..238710b 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
@@ -52,6 +53,7 @@
   }
 
   MLIRContext context;
+  context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
   StatusScopedDiagnosticHandler statusHandler(&context,
                                               /*propagate=*/true);
 
diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt
index f482e3d..a7f6040 100644
--- a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
+# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=: -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
 node {
   name: "tf.Less"
   op: "Less"
diff --git a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir
index f6f32e7..138614d 100644
--- a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir
@@ -3435,4 +3435,19 @@
 }
 // CHECK:  func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
 // CHECK:    %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>)
-// CHECK:    return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
\ No newline at end of file
+// CHECK:    return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
+
+
+func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
+  %0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64>
+  %1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor<?x!tf.string>) -> tensor<?xi64>
+  %2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor<?xi64>, tensor<10x1xi64>) -> tensor<10x?xf64>
+  %3 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
+  %4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<1xi64>) -> tensor<?x10xf64>
+  return %4 : tensor<?x10xf64>
+}
+
+
+// CHECK: func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
+// CHECK:   %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "tftext:custom:SgnnProjection", custom_option = opaque<"tfl", "0x686173685F736565640000000A00000071F86A71318B0AA8023F331CD59AC14AC5E7E95CDE35AD68F474A4711A3C5CC2421F5B20AE52EB1F6275636B6574730002094200030000000100000002000000FFFFFF7F44000000062E0A2601"> : tensor<93xi8>} : (tensor<?x!tf.string>, tensor<?xi64>) -> tensor<?x10xf64>
+// CHECK:   return %0 : tensor<?x10xf64>
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index d02e4e7..ba98e76 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1029,14 +1029,49 @@
   // CHECK: "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
 }
 
-func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
+func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} :
+(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul
+// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_0:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
+
+func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} :
+(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul_transposed_a
+// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_2:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
+
+func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
   %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} :
 (tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
   return %0 : tensor<40x40xf32>
-// CHECK-LABEL: matmul_transposed
+// CHECK-LABEL: matmul_transposed_b
 // CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
 }
 
+func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} :
+(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul_transposed_ab
+// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_1:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
+
 func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
   %0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
   %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir
deleted file mode 100644
index 7e9f66b..0000000
--- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s
-
-func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
-^bb0(%arg0: tensor<3x2xi32>):
-  // CHECK: error: 'unknown_op' op dialect is not registered
-  %0 = "unknown_op"(%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
-  return %0 : tensor<3x2xi32>
-}
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index edbcef3..10ff03a 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -1175,3 +1175,29 @@
 // CHECK: %[[RESULT:.*]] = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor<1x1xf32>
 // CHECK: return %[[RESULT]] : tensor<1x1xf32>
 }
+
+func @SoftMaxWithNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
+  %cst = constant dense<1> : tensor<1xi32>
+  %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
+  %1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
+  %2 = "tfl.exp"(%1) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+  %3 = "tfl.sum"(%2, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
+  %4 = "tfl.div"(%2, %3) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
+  return %4 : tensor<8x128xf32>
+
+// CHECK-LABEL: SoftMaxWithNormalization
+// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32>
+// CHECK: return %[[RESULT]] : tensor<8x128xf32>
+}
+
+func @SoftMaxWithoutNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
+  %cst = constant dense<1> : tensor<1xi32>
+  %0 = "tfl.exp"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+  %1 = "tfl.sum"(%0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
+  %2 = "tfl.div"(%0, %1) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
+  return %2 : tensor<8x128xf32>
+
+// CHECK-LABEL: SoftMaxWithoutNormalization
+// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32>
+// CHECK: return %[[RESULT]] : tensor<8x128xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
index 9e8a957..2b87176 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir
@@ -520,3 +520,42 @@
   return %0 : tensor<100xf32>
   }
 }
+
+// -----
+
+module {
+func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes  {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
+  %0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
+}
+
+// CHECK-LABEL: func @tflite_custom_nms(
+// CHECK-SAME:                          %[[VAL_0:.*]]: tensor<1x100x4xf32>,
+// CHECK-SAME:                          %[[VAL_1:.*]]: tensor<1x100x91xf32>,
+// CHECK-SAME:                          %[[VAL_2:.*]]: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = "TFLite_Detection_PostProcess", tf._reference = "mlir"} {
+// CHECK:         %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {custom_code = "TFLite_Detection_PostProcess", custom_option = opaque<"tfl", "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F010000000A0000009A99193F0000003F5B0000000000000000000040000020410000A0400E06060E0E06060E0E0E322601"> : tensor<217xi8>} : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
+// CHECK:         return %[[VAL_3]]#0, %[[VAL_3]]#1, %[[VAL_3]]#2, %[[VAL_3]]#3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
+// CHECK:       }
+}
+
+// -----
+
+module {
+// expected-error @+1 {{Invalid number of results from TFLite_Detection_PostProcess}}
+func @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes  {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
+
+// expected-error @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}}
+func @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes  {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
+
+// expected-error @+1 {{max_classes_per_detection attribute is not set or not an integer}}
+func @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes  {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
+  %0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
+}
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 297b145..bc460db 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -30,6 +30,7 @@
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Threading.h"
 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
@@ -65,7 +66,6 @@
 // The actual LegalizeTF Pass.
 namespace {
 
-using xla::Status;
 using xla::StatusOr;
 
 constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
@@ -74,6 +74,10 @@
 
 // Legalize operations in functions.
 class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<quant::QuantizationDialect, TFL::TensorFlowLiteDialect>();
+  }
+
  public:
   LegalizeTF() = default;
   LegalizeTF(const LegalizeTF&) {}
@@ -227,26 +231,47 @@
   return success();
 }
 
-// The following is effectively:
-// def : Pat<
-//   (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a,
-//      ConstBoolAttrTrue:$transpose_b),
-//   (TFL_FullyConnectedOp:$__0 $a, $b,
-//     NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
 LogicalResult ConvertTFMatMulOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_matmul_op = cast<TF::MatMulOp>(op);
-  if (tf_matmul_op.transpose_a()) return failure();
-  if (!tf_matmul_op.transpose_b()) return failure();
+  auto lhs = op->getOperand(0);
+  auto rhs = op->getOperand(1);
+  auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
+    RankedTensorType type =
+        input.getType().dyn_cast_or_null<RankedTensorType>();
+    if (!type || type.getRank() != 2) return {failure(), nullptr};
+
+    auto permute_attr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
+    auto permute = rewriter.create<ConstantOp>(
+        op->getLoc(), permute_attr.getType(), permute_attr);
+    llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
+                                            type.getShape()[0]};
+    auto output = rewriter.create<TFL::TransposeOp>(
+        op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
+        input, permute);
+    return {success(), output};
+  };
+
+  // TODO(jpienaar): Remove once handled via dialect conversion.
+  if (tf_matmul_op.transpose_a()) {
+    LogicalResult result = success();
+    std::tie(result, lhs) = transpose(lhs);
+    if (failed(result)) return failure();
+  }
+  if (!tf_matmul_op.transpose_b()) {
+    LogicalResult result = success();
+    std::tie(result, rhs) = transpose(rhs);
+    if (failed(result)) return failure();
+  }
 
   Type output_type = tf_matmul_op.getResult().getType();
-  // TODO(jpienaar): Follow up post shuffle discussion.
   auto no_input = rewriter.create<ConstantOp>(
       op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
   auto fc_op = rewriter.create<FullyConnectedOp>(
-      op->getLoc(), ArrayRef<Type>{output_type}, op->getOperand(0),
-      op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
-      rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
+      op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
+      rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
+      rewriter.getBoolAttr(false));
   rewriter.replaceOp(op, {fc_op.getResult(0)});
   return success();
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc
index 6202507..8b54ca4 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc
@@ -33,6 +33,10 @@
 // cond and body regions.
 struct LegalizeWhile
     : public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<TFL::TensorFlowLiteDialect>();
+  }
+
   void RunOnFunction(FuncOp func);
 
   void runOnOperation() override {
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 559d22d..e53024b 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -552,3 +552,37 @@
      (HasOneUse $reduce)]>;
 }
 
+
+def IsSame : Constraint<CPred<"$0 == $1">>;
+def HasTwoUse : Constraint<CPred<
+  "std::distance($0.use_begin(), $0.use_end()) == 2">>;
+def AxesIsLastDimension : Constraint<CPred<
+  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
+  "$0.cast<DenseIntElementsAttr>().getValue<APInt>({0}) == "
+  "$1.getType().cast<ShapedType>().getRank() - 1">>;
+
+// Convert exp(x)/sum(exp(x)) into softmax.
+def OptimizeToSoftmax : Pat<
+  (TFL_DivOp (TFL_ExpOp:$exp $input),
+             (TFL_SumOp:$sum $sum_input, (ConstantOp I32ElementsAttr: $axes),
+                             ConstBoolAttrTrue), TFL_AF_None),
+  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
+  [(IsSame $exp, $sum_input),
+   (AxesIsLastDimension $axes, $sum_input),
+   (HasTwoUse $exp),
+   (HasOneUse $sum)]>;
+
+// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals
+// with the max normalization.
+def FoldNormalizationIntoSoftmax : Pat<
+  (TFL_SoftmaxOp
+    (TFL_SubOp:$sub $input,
+      (TFL_ReduceMaxOp:$max $max_input, (ConstantOp I32ElementsAttr: $axes),
+                            ConstBoolAttrTrue),
+    TFL_AF_None),
+    $beta),
+  (TFL_SoftmaxOp $input, $beta),
+  [(IsSame $input, $max_input),
+   (AxesIsLastDimension $axes, $max_input),
+   (HasOneUse $sub),
+   (HasOneUse $max)]>;
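For readers unfamiliar with DRR: the CPred strings in these constraints are
spliced verbatim into the matcher that mlir-tblgen generates, with $0/$1
substituted by the bound attributes and values. Only as a sketch of the
resulting shape (the standalone function is hypothetical), AxesIsLastDimension
amounts to:

#include "llvm/ADT/APInt.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"

// $0 = the reduction axes attribute, $1 = the value being reduced.
static bool AxesIsLastDimension(mlir::Attribute axes, mlir::Value input) {
  auto elements = axes.cast<mlir::DenseIntElementsAttr>();
  return elements.getNumElements() == 1 &&
         elements.getValue<llvm::APInt>({0}) ==
             input.getType().cast<mlir::ShapedType>().getRank() - 1;
}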
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
index 0efd718..172ce59d 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
@@ -42,6 +42,7 @@
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
+#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -59,6 +60,7 @@
 
 constexpr char kTFAPIImplements[] = "tf.api_implements";
 constexpr char kTFTextAPIPrefix[] = "tftext:";
+constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
 constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
 
 using mlir::TF::FuncAttr;
@@ -99,59 +101,6 @@
   FuncOp func_;
 };
 
-// Abstracts the conversion of the padded NMS composite function.
-class ConvertNMSPaddedFunc {
- public:
-  explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
-
-  void RewriteFunc() {
-    func_.setAttr(kTFImplements,
-                  StringAttr::get(kTfNMSPadded, func_.getContext()));
-    Value boxes = func_.getArgument(0);
-    Value scores = func_.getArgument(1);
-    Value max_output_size = func_.getArgument(2);
-    Value iou_threshold = func_.getArgument(3);
-    Value score_threshold = func_.getArgument(4);
-    auto output_type0 = func_.getType().getResult(0);
-    auto output_type1 = func_.getType().getResult(1);
-
-    OpBuilder builder(func_.getBody());
-    auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
-        func_.getLoc(), output_type0, output_type1, boxes, scores,
-        max_output_size, iou_threshold, score_threshold);
-
-    builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
-  }
-
-  LogicalResult VerifySignature() {
-    // Verify high-level function signature.
-    // Relevant argument characteristics are checked by the TFL op definition.
-    if (func_.getNumArguments() < 5) {
-      return func_.emitError()
-             << "Invalid number of arguments to "
-                "non_max_suppression_padded_v2 (need atleast 5): "
-             << func_.getNumArguments();
-    }
-    if (func_.getType().getNumResults() != 2) {
-      return func_.emitError() << "Invalid number of results from "
-                                  "non_max_suppression_padded_v2 (need 2): "
-                               << func_.getType().getNumResults();
-    }
-    // The TFLite fused op does not support batching yet.
-    // TODO(b/158709815): Add support for batches with padded NMS.
-    auto boxes_type =
-        func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
-    if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
-      return func_.emitError() << "TFLite does not support batched input for "
-                                  "non_max_suppression_padded";
-    }
-    return success();
-  }
-
- private:
-  FuncOp func_;
-};
-
 // This pass uses mechanisms listed in RFC:
 // https://github.com/tensorflow/community/pull/113
 // It prepares composite functions that are attributed to indicate
@@ -161,6 +110,10 @@
 class PrepareCompositeFunctionsPass
     : public PassWrapper<PrepareCompositeFunctionsPass,
                          OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<TFL::TensorFlowLiteDialect>();
+  }
+
  public:
   explicit PrepareCompositeFunctionsPass() {}
 
@@ -219,6 +172,12 @@
     if (failed(ConvertTFTextAPI(func, api_name, attr))) {
       return signalPassFailure();
     }
+  } else if (api_name == kCustomSSDPostprocessing) {
+    ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
+    if (failed(convert_ssd_postprocess.VerifySignature()) ||
+        failed(convert_ssd_postprocess.RewriteFunc())) {
+      return signalPassFailure();
+    }
   }
 }
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index 07b7aac..783f21f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -69,6 +69,11 @@
 // training quantization simpler.
 class PrepareQuantizePass
     : public PassWrapper<PrepareQuantizePass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<TFL::TensorFlowLiteDialect,
+                    ::mlir::quant::QuantizationDialect>();
+  }
+
  public:
   // Constructor used by the PassRegistration and enforce uint8 quantization.
   // This is only used by test.
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 5d69e09..c2acad3 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -40,6 +40,7 @@
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
@@ -79,13 +80,23 @@
 // Prepare TF operations in functions for subsequent legalization.
 class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
  public:
-  explicit PrepareTFPass() : unfold_batch_matmul_(true) {}
-  explicit PrepareTFPass(bool unfold_batch_matmul)
-      : unfold_batch_matmul_(unfold_batch_matmul) {}
+  PrepareTFPass() = default;
+  PrepareTFPass(const PrepareTFPass &) {}
+  explicit PrepareTFPass(bool unfold_batch_matmul) {
+    unfold_batch_matmul_ = unfold_batch_matmul;
+  }
   void runOnFunction() override;
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
+                    TFL::TensorFlowLiteDialect>();
+  }
+
  private:
-  bool unfold_batch_matmul_;
+  Option<bool> unfold_batch_matmul_{
+      *this, "tfl-unfold-batch-matmul",
+      llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
+      llvm::cl::init(true)};
 };
 
 template <class TFFakeQuantOp>
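Replacing the plain bool with Option<bool> registers the flag with the pass's
own option parser, so it can be set from a textual pass pipeline as well as
from C++; it also explains the new copy constructor, since Option members are
not copyable. A small sketch with a hypothetical pass and flag name:

#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/Pass.h"

struct ExamplePass : public mlir::PassWrapper<ExamplePass, mlir::FunctionPass> {
  ExamplePass() = default;
  // Options must be re-initialized, not copied, when the pass is cloned.
  ExamplePass(const ExamplePass &) {}

  // Settable via e.g. -pass-pipeline="func(example-pass{my-flag=false})".
  Option<bool> my_flag{*this, "my-flag",
                       llvm::cl::desc("A hypothetical boolean option."),
                       llvm::cl::init(true)};

  void runOnFunction() override {
    if (my_flag) {
      // Flag-guarded rewrites would go here.
    }
  }
};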
diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc
index 7838ab1..b32da24 100644
--- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc
@@ -46,7 +46,7 @@
   } else if (auto complex_type = element_type.dyn_cast<mlir::ComplexType>()) {
     auto etype = complex_type.getElementType();
     if (etype.isF32()) {
-      auto dialect = etype.getContext()->getRegisteredDialect("tf");
+      auto dialect = etype.getContext()->getLoadedDialect("tf");
       tensorflow::TensorProto repr;
       repr.set_dtype(tensorflow::DT_COMPLEX64);
 
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
index 081ba7a..f26689f 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
@@ -93,8 +93,9 @@
   LstmUtilsTest() {}
 
   void SetUp() override {
-    RegisterDialects();
     context_ = std::make_unique<mlir::MLIRContext>();
+    context_->loadDialect<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
+                          TensorFlowLiteDialect>();
     builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
     fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
     fused_lstm_func_cifg_ =
@@ -109,12 +110,6 @@
     builder_.reset();
   }
 
-  void RegisterDialects() {
-    mlir::registerDialect<mlir::StandardOpsDialect>();
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    mlir::registerDialect<TensorFlowLiteDialect>();
-  }
-
   FuncOp fused_lstm_func_;
   FuncOp fused_lstm_func_cifg_;
   FuncOp fused_ln_lstm_func_;
diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc
new file mode 100644
index 0000000..e462d4f
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc
@@ -0,0 +1,174 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+
+namespace mlir {
+namespace TFL {
+
+namespace {
+
+// TODO(b/162842801): Consolidate all util definitions of kTFImplements.
+constexpr char kTFImplements[] = "tf._implements";
+constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
+constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
+
+inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
+                                       const std::string& content) {
+  ShapedType type = RankedTensorType::get(
+      {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
+  return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
+                                 type,
+                                 StringRef(content.data(), content.size()));
+}
+
+}  // namespace
+
+void ConvertNMSPaddedFunc::RewriteFunc() {
+  func_.setAttr(kTFImplements,
+                StringAttr::get(kTfNMSPadded, func_.getContext()));
+  Value boxes = func_.getArgument(0);
+  Value scores = func_.getArgument(1);
+  Value max_output_size = func_.getArgument(2);
+  Value iou_threshold = func_.getArgument(3);
+  Value score_threshold = func_.getArgument(4);
+  auto output_type0 = func_.getType().getResult(0);
+  auto output_type1 = func_.getType().getResult(1);
+
+  OpBuilder builder(func_.getBody());
+  auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
+      func_.getLoc(), output_type0, output_type1, boxes, scores,
+      max_output_size, iou_threshold, score_threshold);
+
+  builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
+}
+
+LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
+  // Verify high-level function signature.
+  // Relevant argument characteristics are checked by the TFL op definition.
+  if (func_.getNumArguments() < 5) {
+    return func_.emitError()
+           << "Invalid number of arguments to "
+              "non_max_suppression_padded_v2 (need atleast 5): "
+           << func_.getNumArguments();
+  }
+  if (func_.getType().getNumResults() != 2) {
+    return func_.emitError() << "Invalid number of results from "
+                                "non_max_suppression_padded_v2 (need 2): "
+                             << func_.getType().getNumResults();
+  }
+  // The TFLite fused op does not support batching yet.
+  // TODO(b/158709815): Add support for batches with padded NMS.
+  auto boxes_type = func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
+  if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
+    return func_.emitError() << "TFLite does not support batched input for "
+                                "non_max_suppression_padded";
+  }
+  return success();
+}
+
+LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
+  func_.eraseBody();
+  func_.addEntryBlock();
+  func_.setAttr(kTFImplements,
+                StringAttr::get(kCustomSSDPostprocessing, func_.getContext()));
+
+  OpBuilder builder(func_.getBody());
+  std::string custom_option_buffer;
+  if (failed(CreateNMSCustomOptions(func_, attr_.GetAttrs(),
+                                    custom_option_buffer))) {
+    return failure();
+  }
+  auto op = builder.create<CustomOp>(
+      func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
+      kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer));
+  builder.create<ReturnOp>(func_.getLoc(), op.getResults());
+
+  return success();
+}
+
+LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
+    FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
+  flexbuffers::Builder fbb;
+  size_t start_map = fbb.StartMap();
+
+  if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
+      failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
+      failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
+      failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
+      failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
+      failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
+      failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
+      failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
+      failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
+    return failure();
+  auto use_regular_nms =
+      attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
+  if (!use_regular_nms) {
+    return func.emitError()
+           << "use_regular_nms attribute is not set or not a bool";
+  }
+  fbb.Int("use_regular_nms", use_regular_nms.getValue());
+
+  fbb.EndMap(start_map);
+  fbb.Finish();
+  custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
+  return success();
+}
+
+LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
+    FuncOp func, DictionaryAttr attrs, const std::string& attribute,
+    flexbuffers::Builder* builder) {
+  auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
+  if (!int_attr) {
+    return func.emitError()
+           << attribute.c_str() << " attribute is not set or not an integer";
+  }
+  builder->Int(attribute.c_str(), int_attr.getInt());
+  return success();
+}
+
+LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
+    FuncOp func, DictionaryAttr attrs, const std::string& attribute,
+    flexbuffers::Builder* builder) {
+  auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
+  if (!float_attr) {
+    return func.emitError()
+           << attribute.c_str() << " attribute is not set or not a float";
+  }
+  builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
+  return success();
+}
+
+LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
+  // Verify high-level function signature.
+  if (func_.getNumArguments() != 3) {
+    return func_.emitError()
+           << "Invalid number of arguments to " << kCustomSSDPostprocessing
+           << ": " << func_.getNumArguments();
+  }
+  if (func_.getType().getNumResults() != 4) {
+    return func_.emitError()
+           << "Invalid number of results from " << kCustomSSDPostprocessing
+           << ": " << func_.getType().getNumResults();
+  }
+  return success();
+}
+
+}  // namespace TFL
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.h b/tensorflow/compiler/mlir/lite/utils/nms_utils.h
new file mode 100644
index 0000000..6a9035e
--- /dev/null
+++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.h
@@ -0,0 +1,76 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header file defines common utilities used by TFLite transformation
+// passes to work with NMS ops.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
+
+#include <string>
+
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
+
+namespace mlir {
+namespace TFL {
+
+// Abstracts the conversion of the padded NMS composite function.
+class ConvertNMSPaddedFunc {
+ public:
+  explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
+
+  void RewriteFunc();
+
+  LogicalResult VerifySignature();
+
+ private:
+  FuncOp func_;
+};
+
+// Abstracts the conversion of the SSD post-processing composite function to
+// TFLite.
+class ConvertSSDPostProcessFunc {
+ public:
+  explicit ConvertSSDPostProcessFunc(FuncOp func, mlir::TF::FuncAttr attr)
+      : func_(func), attr_(attr) {}
+
+  LogicalResult RewriteFunc();
+
+  LogicalResult VerifySignature();
+
+ private:
+  LogicalResult CreateNMSCustomOptions(FuncOp func, DictionaryAttr attrs,
+                                       std::string& custom_option_buffer);
+
+  LogicalResult AddIntAttr(FuncOp func, DictionaryAttr attrs,
+                           const std::string& attribute,
+                           flexbuffers::Builder* builder);
+
+  LogicalResult AddFloatAttr(FuncOp func, DictionaryAttr attrs,
+                             const std::string& attribute,
+                             flexbuffers::Builder* builder);
+
+  FuncOp func_;
+  mlir::TF::FuncAttr attr_;
+};
+
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc
index 96d22cb..cce8038 100644
--- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc
@@ -47,6 +47,7 @@
 
 constexpr char kNgrams[] = "tftext:Ngrams";
 constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
+constexpr char kCustomSgnnProjection[] = "tftext:custom:SgnnProjection";
 constexpr char kTFImplements[] = "tf._implements";
 
 using mlir::TF::FuncAttr;
@@ -56,9 +57,9 @@
                                        const std::string& content) {
   ShapedType type = RankedTensorType::get(
       {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
-  return OpaqueElementsAttr::get(
-      builder->getContext()->getRegisteredDialect("tfl"), type,
-      StringRef(content.data(), content.size()));
+  return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
+                                 type,
+                                 StringRef(content.data(), content.size()));
 }
 
 inline TensorType GetInputType(FuncOp func, int idx) {
@@ -269,6 +270,85 @@
   return success();
 }
 
+LogicalResult VerifySgnnProjection(FuncOp func, FuncAttr attr) {
+  if (func.getType().getNumInputs() != 2 ||
+      func.getType().getNumResults() != 1) {
+    return func.emitError() << "Expects 2 inputs and 1 output";
+  }
+  auto values_type = GetInputType(func, 0);
+  if (!values_type || !values_type.getElementType().isa<StringType>()) {
+    return func.emitError() << "First input should be a string tensor";
+  }
+  auto row_splits_type = GetInputType(func, 1);
+  if (!row_splits_type ||
+      !row_splits_type.getElementType().isa<IntegerType>()) {
+    return func.emitError() << "Second input should be an integer tensor";
+  }
+
+  auto hash_seed =
+      attr.GetAttrs().get("hash_seed").dyn_cast_or_null<ArrayAttr>();
+  if (!hash_seed) {
+    return func.emitError()
+           << "'hash_seed' attribute is not set or not an array";
+  }
+  auto output_type = GetResultType(func, 0);
+  if (!output_type || !output_type.getElementType().isa<FloatType>() ||
+      !RankEquals(output_type, 2)) {
+    return func.emitError() << "Output should be a 2D float tensor.";
+  }
+  if (output_type.getDimSize(1) != hash_seed.size()) {
+    return func.emitError()
+           << "Output's second dimension should be the number of hash seeds.";
+  }
+
+  auto buckets = attr.GetAttrs().get("buckets").dyn_cast_or_null<IntegerAttr>();
+  if (!buckets) {
+    return func.emitError()
+           << "'buckets' attribute is not set or not an integer";
+  }
+
+  return success();
+}
+
+LogicalResult CreateSgnnProjectionCustomOption(
+    FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
+  flexbuffers::Builder fbb;
+  size_t start_map = fbb.StartMap();
+
+  auto hash_seed = attrs.get("hash_seed").dyn_cast_or_null<ArrayAttr>();
+  auto vector_start = fbb.StartVector("hash_seed");
+  for (int i = 0; i < hash_seed.size(); i++) {
+    fbb.Add(static_cast<int32_t>(
+        (hash_seed.getValue().data() + i)->dyn_cast<IntegerAttr>().getInt()));
+  }
+  fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false);
+
+  auto buckets = attrs.get("buckets").dyn_cast_or_null<IntegerAttr>();
+  fbb.Int("buckets", buckets.getInt());
+
+  fbb.EndMap(start_map);
+  fbb.Finish();
+  custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
+  return success();
+}
+
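The `hash_seed` attribute is written as a typed flexbuffer vector, so the reader side can walk it directly. A sketch using only the flexbuffers API; `ReadSgnnOptions` is a hypothetical name:

```c++
#include <cstddef>
#include <cstdint>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers

// Illustrative reader for the SGNN custom options serialized above.
void ReadSgnnOptions(const uint8_t* buffer, size_t size) {
  auto map = flexbuffers::GetRoot(buffer, size).AsMap();
  auto seeds = map["hash_seed"].AsTypedVector();
  for (size_t i = 0; i < seeds.size(); ++i) {
    int32_t seed = seeds[i].AsInt32();
    (void)seed;  // Feed into the projection hash.
  }
  int64_t buckets = map["buckets"].AsInt64();
  (void)buckets;
}
```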
+LogicalResult ConvertSgnnProjection(FuncOp func, llvm::StringRef api,
+                                    FuncAttr attr) {
+  // See more details in tensorflow_models/sequence_projection/sgnn/sgnn.py
+  func.eraseBody();
+  func.addEntryBlock();
+  func.setAttr(kTFImplements, attr);
+  OpBuilder builder(func.getBody());
+  std::string custom_option_buffer;
+  if (failed(CreateSgnnProjectionCustomOption(func, attr.GetAttrs(),
+                                              custom_option_buffer))) {
+    return failure();
+  }
+  auto op = builder.create<CustomOp>(
+      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
+      CustomOption(&builder, custom_option_buffer));
+  builder.create<ReturnOp>(func.getLoc(), op.getResults());
+  return success();
+}
 }  // namespace
 
 LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
@@ -281,6 +361,10 @@
     if (succeeded(VerifyNgrams(func))) {
       return ConvertNgrams(func, api, attr);
     }
+  } else if (api.str() == kCustomSgnnProjection) {
+    if (succeeded(VerifySgnnProjection(func, attr))) {
+      return ConvertSgnnProjection(func, api, attr);
+    }
   }
   return failure();
 }
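With the new branch in place, SGNN projection is reached through the same `ConvertTFTextAPI` entry point as the other TF.Text fusions. A minimal sketch, assuming the caller has already extracted the `tf._implements` `FuncAttr`; the wrapper `ConvertSgnn` is hypothetical:

```c++
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"

// Hypothetical caller: dispatch a composite function to the TF.Text
// converter; the api string must match kCustomSgnnProjection.
mlir::LogicalResult ConvertSgnn(mlir::FuncOp func, mlir::TF::FuncAttr attr) {
  return mlir::TFL::ConvertTFTextAPI(func, "tftext:custom:SgnnProjection",
                                     attr);
}
```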
diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
index 8be6fac..d97e12f 100644
--- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
+++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
@@ -91,16 +91,14 @@
   return *global;
 }
 
-static void RegisterDialects() {
-  static bool init_once = []() {
-    mlir::registerDialect<mlir::StandardOpsDialect>();
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    mlir::registerDialect<mlir::shape::ShapeDialect>();
-    mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
-    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
-    return true;
-  }();
-  (void)init_once;
+static void RegisterDialects(mlir::DialectRegistry& registry) {
+  // clang-format off
+  registry.insert<mlir::StandardOpsDialect,
+                  mlir::TF::TensorFlowDialect,
+                  mlir::shape::ShapeDialect,
+                  mlir::tf_device::TensorFlowDeviceDialect,
+                  mlir::tf_executor::TensorFlowExecutorDialect>();
+  // clang-format on
 }
 
 Status MlirFunctionOptimizationPass::Run(
@@ -126,8 +124,8 @@
                           << " passes)";
 
   GraphDebugInfo debug_info;
-  RegisterDialects();
   mlir::MLIRContext context;
+  RegisterDialects(context.getDialectRegistry());
   GraphImportConfig import_config;
   import_config.graph_as_function = true;
   import_config.control_outputs = *control_ret_node_names;
@@ -206,8 +204,8 @@
                           << " passes)";
 
   GraphDebugInfo debug_info;
-  RegisterDialects();
   mlir::MLIRContext context;
+  RegisterDialects(context.getDialectRegistry());
   GraphImportConfig import_config;
   import_config.upgrade_legacy = true;
   // Restrict functionalization to TPU nodes to avoid problems in v1 session
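This hunk is part of the MLIR-wide migration from global dialect registration to per-context `DialectRegistry` entries: dialects are declared up front and loaded into the context on demand. A minimal sketch of the new pattern, with `SetUpContext` as a hypothetical wrapper around the `RegisterDialects` helper above:

```c++
#include "mlir/IR/MLIRContext.h"  // from @llvm-project

// Sketch of the registry-based setup this change migrates to. loadAll() is
// only needed when dialects must be materialized eagerly, e.g. before
// parsing textual IR.
void SetUpContext(mlir::MLIRContext& context) {
  RegisterDialects(context.getDialectRegistry());
  context.getDialectRegistry().loadAll(&context);
}
```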
diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
index bce0ed4..6b60574 100644
--- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
+++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
@@ -28,6 +28,7 @@
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
 
 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
   return absl::string_view(ref.data(), ref.size());
@@ -103,62 +104,16 @@
 
 bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
 
-namespace {
-// Derives name from location.
-std::string GetNameFromLoc(mlir::Location loc) {
-  llvm::SmallVector<llvm::StringRef, 8> loc_names;
-  llvm::SmallVector<mlir::Location, 8> locs;
-  locs.push_back(loc);
-  bool names_is_nonempty = false;
-
-  while (!locs.empty()) {
-    mlir::Location curr_loc = locs.pop_back_val();
-
-    if (auto name_loc = curr_loc.dyn_cast<mlir::NameLoc>()) {
-      // Add name in NameLoc. For NameLoc we also account for names due to ops
-      // in functions where the op's name is first.
-      auto name = name_loc.getName().strref().split('@').first;
-      loc_names.push_back(name);
-      if (!name.empty()) names_is_nonempty = true;
-      continue;
-    } else if (auto call_loc = curr_loc.dyn_cast<mlir::CallSiteLoc>()) {
-      // Add name if CallSiteLoc's callee has a NameLoc (as should be the
-      // case if imported with DebugInfo).
-      if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>()) {
-        auto name = name_loc.getName().strref().split('@').first;
-        loc_names.push_back(name);
-        if (!name.empty()) names_is_nonempty = true;
-        continue;
-      }
-    } else if (auto fused_loc = curr_loc.dyn_cast<mlir::FusedLoc>()) {
-      // Push all locations in FusedLoc in reverse order, so locations are
-      // visited based on order in FusedLoc.
-      auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
-      locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
-      continue;
-    }
-
-    // Location is not a supported, so an empty StringRef is added.
-    loc_names.push_back(llvm::StringRef());
-  }
-
-  if (names_is_nonempty)
-    return llvm::join(loc_names.begin(), loc_names.end(), ";");
-
-  return "";
-}
-}  // anonymous namespace
-
 std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
   if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
-    auto name_from_loc = GetNameFromLoc(op->getLoc());
+    auto name_from_loc = mlir::GetNameFromLoc(op->getLoc());
     if (!name_from_loc.empty()) return name_from_loc;
     // If the location is none of the expected types, then simply use name
     // generated using the op type.
     return std::string(op->getName().getStringRef());
   }
   auto val = op_or_val.dyn_cast<mlir::Value>();
-  auto name_from_loc = GetNameFromLoc(val.getLoc());
+  auto name_from_loc = mlir::GetNameFromLoc(val.getLoc());
   if (!name_from_loc.empty()) return name_from_loc;
   // If the location is none of the expected types, then simply use name
   // generated using the op type. Follow TF convention and append the result
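The location-walking logic deleted here now lives in the shared `name_utils` target, so any emitter can reuse it. A small usage sketch mirroring the mapper code above; `DebugName` is a hypothetical helper:

```c++
#include <string>

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/utils/name_utils.h"

// Derive a debug name from an op's location, falling back to the op type
// name when the location carries no usable name.
std::string DebugName(mlir::Operation* op) {
  std::string name = mlir::GetNameFromLoc(op->getLoc());
  if (name.empty()) name = std::string(op->getName().getStringRef());
  return name;
}
```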
diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD
index 5bbfba7..6a47be3 100644
--- a/tensorflow/compiler/mlir/python/BUILD
+++ b/tensorflow/compiler/mlir/python/BUILD
@@ -10,6 +10,7 @@
     deps = [
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
+        "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
         "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
@@ -35,6 +36,7 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc
index 5ce0ca8..8bec288 100644
--- a/tensorflow/compiler/mlir/python/mlir.cc
+++ b/tensorflow/compiler/mlir/python/mlir.cc
@@ -16,11 +16,13 @@
 #include <string>
 
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/InitAllPasses.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
@@ -150,6 +152,7 @@
                                         bool show_debug_info,
                                         TF_Status *status) {
   mlir::MLIRContext context;
+  mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
   mlir::OwningModuleRef module;
   {
     mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
@@ -164,6 +167,7 @@
   mlir::PassManager pm(&context);
   std::string error;
   llvm::raw_string_ostream error_stream(error);
+  mlir::registerAllPasses();
   if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
                  ("Invalid pass_pipeline: " + error_stream.str()).c_str());
diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
index 63ca4c7..6cd49cf 100644
--- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
+++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc
@@ -22,22 +22,25 @@
 #include "mlir/Parser.h"  // from @llvm-project
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/python/lib/core/pybind11_lib.h"
 #include "tensorflow/python/lib/core/pybind11_status.h"
 
 PYBIND11_MODULE(mlir_wrapper, m) {
-  m.def("registerDialects", []() {
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
-    mlir::registerDialect<mlir::StandardOpsDialect>();
+  m.def("preloadTensorFlowDialects", [](mlir::MLIRContext &context) {
+    mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
+    context.getDialectRegistry().loadAll(&context);
   });
+
   m.def("verify", [](std::string input) {
     llvm::SourceMgr SM = llvm::SourceMgr();
     SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
                           llvm::SMLoc());
     mlir::MLIRContext ctx;
+    mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
+    ctx.getDialectRegistry().loadAll(&ctx);
     auto module = mlir::parseSourceFile(SM, &ctx);
     if (!module) {
       return false;
diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py
index 45c8dce..f987018 100644
--- a/tensorflow/compiler/mlir/runlit.cfg.py
+++ b/tensorflow/compiler/mlir/runlit.cfg.py
@@ -74,7 +74,7 @@
     'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
     'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
     'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir',
-    'kernel-gen-opt', 'xla-thunks-opt'
+    'kernel-gen-opt', 'xla-thunks-opt', 'tfjs-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 0ddf390..1344ded 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -513,6 +513,7 @@
         "ir/tf_saved_model.cc",
     ],
     hdrs = [
+        "dialect_registration.h",
         "ir/tf_device.h",
         "ir/tf_executor.h",
         "ir/tf_ops.h",
@@ -777,6 +778,7 @@
         "transforms/sink_constant.cc",
         "transforms/stack_ops_decomposition.cc",
         "transforms/tensor_array_ops_decomposition.cc",
+        "transforms/tensor_device_copy_conversion.cc",
         "transforms/tensor_list_ops_decomposition.cc",
         "transforms/test_resource_alias_analysis.cc",
         "transforms/test_side_effect_analysis.cc",
@@ -792,6 +794,7 @@
         "transforms/tpu_identity_pruning.cc",
         "transforms/tpu_merge_variables_with_execute.cc",
         "transforms/tpu_outside_compilation_cluster.cc",
+        "transforms/tpu_resource_read_for_write.cc",
         "transforms/tpu_rewrite_pass.cc",
         "transforms/tpu_sharding_identification_pass.cc",
         "transforms/tpu_space_to_depth_pass.cc",
@@ -820,6 +823,7 @@
         ":device_util",
         ":error_util",
         ":export_tf_dialect_op",
+        ":lower_tf_lib",
         ":mangling_util",
         ":tensorflow",
         ":tensorflow_analysis",
@@ -957,6 +961,7 @@
         "//tensorflow/cc/saved_model:loader_lite",
         "//tensorflow/cc/saved_model:loader_util",
         "//tensorflow/compiler/jit:shape_inference_helpers",
+        "//tensorflow/compiler/mlir:name_utils",
         "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
         "//tensorflow/compiler/tf2xla:functionalize_control_flow",
         "//tensorflow/compiler/xla:status_macros",
@@ -1072,6 +1077,7 @@
         ":export_utils",
         ":tensorflow",
         "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/stream_executor/lib",
@@ -1087,6 +1093,7 @@
     srcs = ["translate/translate_tf_dialect_op.cc"],
     deps = [
         ":export_tf_dialect_op",
+        ":tensorflow",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
@@ -1313,6 +1320,7 @@
     deps = [
         ":convert_graphdef",
         ":mlir_roundtrip_flags",
+        ":tensorflow",
         "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
@@ -1407,6 +1415,7 @@
     deps = [
         ":convert_graphdef",
         ":mlir_roundtrip_flags",
+        ":tensorflow",
         ":translate_cl_options",
         ":translate_lib",
         "//tensorflow/core:protos_all_cc",
@@ -1508,6 +1517,7 @@
     "@llvm-project//mlir:TransformUtils",
     "@llvm-project//mlir:Transforms",
     "//tensorflow/compiler/mlir/hlo:hlo",
+    "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
     "//tensorflow/compiler/mlir/hlo:sink_constants_to_control_flow",
     "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
     "//tensorflow/compiler/mlir/xla:type_to_shape",
diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD
index 243f4b5..5c6f396 100644
--- a/tensorflow/compiler/mlir/tensorflow/c/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD
@@ -2,7 +2,6 @@
     "//tensorflow:tensorflow.bzl",
     "tf_copts",
     "tf_cuda_library",
-    "tfe_xla_copts",
 )
 
 package(
@@ -20,7 +19,7 @@
     srcs = [
         "c_api_unified_experimental_mlir.cc",
     ],
-    copts = tf_copts() + tfe_xla_copts(),
+    copts = tf_copts(),
     deps = [
         "//tensorflow/c:c_api",
         "//tensorflow/c:tensor_interface",
diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
index c62d62a..6bfe4c3 100644
--- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
+++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
@@ -43,6 +43,7 @@
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/c/tf_status_internal.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -74,15 +75,9 @@
 
 namespace {
 
-static void RegisterDialects() {
-  static bool init_once = []() {
-    mlir::registerDialect<mlir::StandardOpsDialect>();
-    mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
-    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    return true;
-  }();
-  (void)init_once;
+void RegisterDialects(mlir::MLIRContext& ctx) {
+  mlir::RegisterAllTensorFlowDialects(ctx.getDialectRegistry());
+  ctx.getDialectRegistry().loadAll(&ctx);
 }
 
 Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
@@ -239,6 +234,7 @@
       : TracingContext(kMlir),
         context_(std::make_unique<MLIRContext>()),
         builder_(context_.get()) {
+    RegisterDialects(*context_);
     // TODO(aminim) figure out the location story here
     module_ = ModuleOp::create(builder_.getUnknownLoc());
     func_ = FuncOp::create(builder_.getUnknownLoc(), name,
@@ -456,7 +452,8 @@
   return Unimplemented("SetAttrFloat has not been implemented yet.");
 }
 Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) {
-  return Unimplemented("SetAttrBool has not been implemented yet.");
+  attrs_[attr_name] = BoolAttr::get(value, context_);
+  return Status::OK();
 }
 Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims,
                                     const int num_dims) {
@@ -666,7 +663,6 @@
 
 extern "C" {
 TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
-  RegisterDialects();
   return new MlirFunctionContext(fn_name);
 }
 }
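Two things change in this file: dialect registration moves into the `MlirFunctionContext` constructor, and `SetAttrBool` gains a real implementation backed by `BoolAttr`. Note the `BoolAttr::get(value, context)` argument order used by the MLIR revision this change builds against; a tiny sketch, with `MakeFlag` as a hypothetical helper:

```c++
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project

// BoolAttr::get takes (value, context) in this MLIR revision.
mlir::BoolAttr MakeFlag(mlir::MLIRContext* context, bool value) {
  return mlir::BoolAttr::get(value, context);
}
```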
diff --git a/tensorflow/compiler/mlir/tensorflow/dialect_registration.h b/tensorflow/compiler/mlir/tensorflow/dialect_registration.h
new file mode 100644
index 0000000..a63bfd1
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/dialect_registration.h
@@ -0,0 +1,37 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
+
+namespace mlir {
+// Inserts all the TensorFlow dialects in the provided registry. This is
+// intended for tools that need to register dialects before parsing .mlir files.
+inline void RegisterAllTensorFlowDialects(DialectRegistry &registry) {
+  registry.insert<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
+                  mlir::tf_device::TensorFlowDeviceDialect,
+                  mlir::tf_executor::TensorFlowExecutorDialect,
+                  mlir::tf_saved_model::TensorFlowSavedModelDialect>();
+}
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_
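A typical consumer of the new header looks like the sketch below; it is the same pattern this diff applies in `mlir.cc` and `mlir_wrapper.cc` (the `main` scaffold is illustrative):

```c++
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"

int main() {
  mlir::MLIRContext context;
  // Declare the TF dialects; they are loaded into the context on first use,
  // e.g. when parsing .mlir files.
  mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
  return 0;
}
```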
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index ea9ae5d..eced738 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -250,33 +250,6 @@
 // tf_executor.fetch
 //===----------------------------------------------------------------------===//
 
-namespace {
-
-void Print(FetchOp fetch, OpAsmPrinter &p) {
-  p << fetch.getOperationName();
-  if (fetch.getNumOperands() > 0) {
-    p << ' ';
-    p.printOperands(fetch.operand_begin(), fetch.operand_end());
-    p << " : ";
-    interleaveComma(fetch.getOperandTypes(), p);
-  }
-  p.printOptionalAttrDict(fetch.getAttrs());
-}
-
-ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> opInfo;
-  SmallVector<Type, 2> types;
-  llvm::SMLoc loc = parser.getCurrentLocation();
-  return failure(parser.parseOperandList(opInfo) ||
-                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
-                 parser.resolveOperands(opInfo, types, loc, result.operands) ||
-                 parser.parseOptionalAttrDict(result.attributes)
-
-  );
-}
-
-}  // anonymous namespace
-
 //===----------------------------------------------------------------------===//
 // tf_executor.island
 //===----------------------------------------------------------------------===//
@@ -411,31 +384,6 @@
 // tf_executor.yield
 //===----------------------------------------------------------------------===//
 
-namespace {
-
-void Print(YieldOp yield, OpAsmPrinter &p) {
-  p << yield.getOperationName();
-  if (yield.getNumOperands() > 0) {
-    p << ' ';
-    p.printOperands(yield.operand_begin(), yield.operand_end());
-    p << " : ";
-    interleaveComma(yield.getOperandTypes(), p);
-  }
-  p.printOptionalAttrDict(yield.getAttrs());
-}
-
-ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> op_info;
-  SmallVector<Type, 2> types;
-  llvm::SMLoc loc = parser.getCurrentLocation();
-  return failure(parser.parseOperandList(op_info) ||
-                 (!op_info.empty() && parser.parseColonTypeList(types)) ||
-                 parser.resolveOperands(op_info, types, loc, result.operands) ||
-                 parser.parseOptionalAttrDict(result.attributes));
-}
-
-}  // anonymous namespace
-
 //===----------------------------------------------------------------------===//
 // tf_executor.Switch
 //===----------------------------------------------------------------------===//
@@ -848,23 +796,6 @@
   return success();
 }
 
-void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) {
-  p << next_iteration.getOperationName() << " : " << next_iteration.getType(0);
-  p.printOptionalAttrDict(next_iteration.getAttrs());
-}
-
-ParseResult ParseNextIterationSourceOp(OpAsmParser &parser,
-                                       OperationState &result) {
-  SmallVector<Type, 1> types;
-  if (parser.parseColonTypeList(types)) return failure();
-
-  MLIRContext *context = parser.getBuilder().getContext();
-  Type token_type = TokenType::get(context);
-  Type control_type = ControlType::get(context);
-  result.addTypes({types.front(), token_type, control_type});
-  return parser.parseOptionalAttrDict(result.attributes);
-}
-
 }  // anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -891,36 +822,6 @@
   return success();
 }
 
-void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) {
-  p << next_iteration.getOperationName() << " [";
-  p.printOperand(next_iteration.getOperand(0));
-  p << "] ";
-  p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
-  p << " : " << next_iteration.getOperand(1).getType();
-  p.printOptionalAttrDict(next_iteration.getAttrs());
-}
-
-ParseResult ParseNextIterationSinkOp(OpAsmParser &parser,
-                                     OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> op_infos;
-  llvm::SMLoc loc = parser.getCurrentLocation();
-
-  // First type is always the token consumed from the NextIteration.source
-  Type token_type = TokenType::get(parser.getBuilder().getContext());
-  SmallVector<Type, 1> types = {token_type};
-
-  if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) ||
-      parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
-    return failure();
-
-  Type control_type = ControlType::get(parser.getBuilder().getContext());
-  types.append(op_infos.size() - 2, control_type);
-  if (parser.resolveOperands(op_infos, types, loc, result.operands))
-    return failure();
-
-  return parser.parseOptionalAttrDict(result.attributes);
-}
-
 }  // anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -959,32 +860,6 @@
 // tf_executor.ControlTrigger
 //===----------------------------------------------------------------------===//
 
-namespace {
-
-void Print(ControlTriggerOp trigger, OpAsmPrinter &p) {
-  p << trigger.getOperationName() << ' ';
-  p.printOperands(trigger.getOperands());
-  p.printOptionalAttrDict(trigger.getAttrs());
-}
-
-ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> op_infos;
-  SmallVector<Type, 1> types;
-  llvm::SMLoc loc = parser.getCurrentLocation();
-
-  if (parser.parseOperandList(op_infos)) return failure();
-  Type control_type = ControlType::get(parser.getBuilder().getContext());
-  types.append(op_infos.size(), control_type);
-  if (parser.resolveOperands(op_infos, types, loc, result.operands))
-    return failure();
-
-  // Single control as the only output
-  result.types.push_back(control_type);
-  return parser.parseOptionalAttrDict(result.attributes);
-}
-
-}  // anonymous namespace
-
 //===----------------------------------------------------------------------===//
 // tf_executor.LoopCond
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
index 3081018..de2d248 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
@@ -47,10 +47,12 @@
 }
 
 // Control type.
-def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">;
+def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">,
+                     BuildableType<"$_builder.getType<ControlType>()">;
 
 // Token type.
-def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">;
+def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">,
+                   BuildableType<"$_builder.getType<TokenType>()">;
 
 // TODO(hinsu): Define and use TensorType instead of AnyType for data operands
 // and results. For example, MergeOp output type.
@@ -148,7 +150,11 @@
     }]>
    ];
 
+  let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict";
+
   let verifier = ?;
+  let printer = ?;
+  let parser = ?;
 }
 
 def TfExecutor_IslandOp : TfExecutor_Op<"island",
@@ -229,7 +235,11 @@
     }]>
    ];
 
+  let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict";
+
   let verifier = ?;
+  let printer = ?;
+  let parser = ?;
 }
 
 def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
@@ -466,6 +476,10 @@
     }
   }];
 
+  let assemblyFormat = "`:` type($output) attr-dict";
+
+  let printer = ?;
+  let parser = ?;
 }
 
 
@@ -527,6 +541,11 @@
       result.attributes.append(attributes.begin(), attributes.end());
     }]>
    ];
+
+  let assemblyFormat = " `[` $token `]` $input (`,` $controlInputs^)? `:` type($input) attr-dict";
+
+  let printer = ?;
+  let parser = ?;
 }
 
 def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
@@ -552,7 +571,7 @@
        .Attr("T: type")
 
     For example:
-     %1:2 = tf_executor.Exit %0#0 {T: "tfdtype$DT_INT32"} : tensor<*xi32>
+     %1:2 = tf_executor.Exit %0#0 : tensor<*xi32> {T: "tfdtype$DT_INT32"}
 
     Note: Additional result corresponds to the control output.
   }];
@@ -607,6 +626,11 @@
       result.attributes.append(attributes.begin(), attributes.end());
     }]>
    ];
+
+  let assemblyFormat = "$controlInputs attr-dict";
+
+  let printer = ?;
+  let parser = ?;
 }
 
 def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index d4bae60..faf7d42 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -52,6 +52,12 @@
 def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes acos of x element-wise.";
 
+  let description = [{
+Given an input tensor, the `tf.math.acos` operation returns the inverse cosine
+of each element. If `y = tf.math.cos(x)`, then `x = tf.math.acos(y)`.
+
+The input range is `[-1, 1]` and the output range is `[0, pi]`.
+  }];
+
   let arguments = (ins
     TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
   );
@@ -94,6 +100,10 @@
   let description = [{
 *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+
+Given two input tensors, the `tf.add` operation computes the element-wise sum.
+
+Both the inputs and the output have a range of `(-inf, inf)`.
   }];
 
   let arguments = (ins
@@ -136,31 +146,6 @@
   let hasFolder = 1;
 }
 
-def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
-                 WithBroadcastableBinOpBuilder {
-  let summary = "Returns x + y element-wise.";
-
-  let description = [{
-*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-  }];
-
-  let arguments = (ins
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x,
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y
-  );
-
-  let results = (outs
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-
-  let hasCanonicalizer = 1;
-
-  let hasFolder = 1;
-}
-
 def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> {
   let summary = "Adjust the contrast of one or more images.";
 
@@ -1740,6 +1725,24 @@
   let hasCanonicalizer = 1;
 }
 
+def TF_ConfigureDistributedTPUOp : TF_Op<"ConfigureDistributedTPU", []> {
+  let summary = [{
+Sets up the centralized structures for a distributed TPU system.
+  }];
+
+  let arguments = (ins
+    StrAttr:$embedding_config,
+    StrAttr:$tpu_embedding_config,
+    DefaultValuedAttr<BoolAttr, "false">:$is_global_init,
+    DefaultValuedAttr<BoolAttr, "false">:$enable_whole_mesh_compilations,
+    DefaultValuedAttr<BoolAttr, "true">:$compilation_failure_closes_chips
+  );
+
+  let results = (outs
+    TF_StrTensor:$topology
+  );
+}
+
 def TF_ConjOp : TF_Op<"Conj", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns the complex conjugate of a complex number.";
 
@@ -2236,6 +2239,48 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Permute input tensor from `src_format` to `dst_format`.";
+
+  let description = [{
+Input tensor must be a vector of size 4, or a 4x2 tensor.
+
+For example, with `src_format` of `NHWC`, `dst_format` of `NCHW`, and inputs:
+```
+[1, 2, 3, 4]
+```
+and
+```
+[[1, 2, 3, 4],
+ [5, 6, 7, 8]]
+```
+then the outputs will be (respectively):
+```
+[1, 4, 2, 3]
+```
+and
+```
+[[1, 4, 2, 3],
+ [5, 8, 6, 7]]
+```
+  }];
+
+  let arguments = (ins
+    TF_I32OrI64Tensor:$x,
+
+    DefaultValuedAttr<StrAttr, "NHWC">:$src_format,
+    DefaultValuedAttr<StrAttr, "NCHW">:$dst_format
+  );
+
+  let results = (outs
+    TF_I32OrI64Tensor:$y
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+  let verifier = [{ return Verify(*this); }];
+}
+
 def TF_DebugIdentityV2Op : TF_Op<"DebugIdentityV2", []> {
   let summary = "Debug Identity V2 Op.";
 
@@ -2744,27 +2789,6 @@
   let hasFolder = 1;
 }
 
-def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
-                    WithBroadcastableBinOpBuilder {
-  let summary = "Returns 0 if the denominator is zero.";
-
-  let description = [{
-*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-  }];
-
-  let arguments = (ins
-    TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x,
-    TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y
-  );
-
-  let results = (outs
-    TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-}
-
 def TF_DynamicStitchOp : TF_Op<"DynamicStitch", [NoSideEffect, SameVariadicOperandSize]> {
   let summary = [{
 Interleave the values from the `data` tensors into a single tensor.
@@ -3811,6 +3835,95 @@
   }];
 }
 
+def TF_FusedBatchNormV2Op : TF_Op<"FusedBatchNormV2", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
+  let summary = "Batch normalization.";
+
+  let description = [{
+Note that the size of 4D Tensors is defined by either "NHWC" or "NCHW".
+The size of 1D Tensors matches the dimension C of the 4D Tensors.
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32]>:$x,
+    F32Tensor:$scale,
+    F32Tensor:$offset,
+    F32Tensor:$mean,
+    F32Tensor:$variance,
+
+    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+    DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
+    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+    DefaultValuedAttr<BoolAttr, "true">:$is_training
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32]>:$y,
+    F32Tensor:$batch_mean,
+    F32Tensor:$batch_variance,
+    F32Tensor:$reserve_space_1,
+    F32Tensor:$reserve_space_2
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
+
+  let extraClassDeclaration = [{
+    // TF_FoldOperandsTransposeInterface:
+    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
+    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
+    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+
+    // TF_LayoutSensitiveInterface:
+    StringRef GetOptimalLayout(const RuntimeDevices& devices);
+    LogicalResult UpdateDataFormat(StringRef data_format);
+  }];
+}
+
+def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
+  let summary = "Batch normalization.";
+
+  let description = [{
+Note that the size of 4D Tensors is defined by either "NHWC" or "NCHW".
+The size of 1D Tensors matches the dimension C of the 4D Tensors.
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32]>:$x,
+    F32Tensor:$scale,
+    F32Tensor:$offset,
+    F32Tensor:$mean,
+    F32Tensor:$variance,
+
+    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+    DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
+    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+    DefaultValuedAttr<BoolAttr, "true">:$is_training
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32]>:$y,
+    F32Tensor:$batch_mean,
+    F32Tensor:$batch_variance,
+    F32Tensor:$reserve_space_1,
+    F32Tensor:$reserve_space_2,
+    F32Tensor:$reserve_space_3
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
+
+  let extraClassDeclaration = [{
+    // TF_FoldOperandsTransposeInterface:
+    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
+    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
+    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+
+    // TF_LayoutSensitiveInterface:
+    StringRef GetOptimalLayout(const RuntimeDevices& devices);
+    LogicalResult UpdateDataFormat(StringRef data_format);
+  }];
+}
+
 def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
   let summary = "Gather slices from `params` according to `indices`.";
 
@@ -6171,14 +6284,14 @@
   }];
 
   let arguments = (ins
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
     TF_I32OrI64Tensor:$reduction_indices,
 
     DefaultValuedAttr<BoolAttr, "false">:$keep_dims
   );
 
   let results = (outs
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@@ -6282,25 +6395,36 @@
   }];
 }
 
-def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
-                   WithBroadcastableBinOpBuilder {
-  let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise.";
+def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
+  let summary = "Computes the mean of elements across dimensions of a tensor.";
 
   let description = [{
-*NOTE*: `Maximum` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+Reduces `input` along the dimensions given in `axis`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`axis`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
   }];
 
   let arguments = (ins
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x,
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
+    TF_I32OrI64Tensor:$reduction_indices,
+
+    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
   );
 
   let results = (outs
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let extraClassDeclaration = [{
+    // TF_FoldOperandsTransposeInterface:
+    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
+    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
+    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+  }];
 }
 
 def TF_MergeSummaryOp : TF_Op<"MergeSummary", [NoSideEffect, SameOperandsAndResultType]> {
@@ -6366,14 +6490,14 @@
   }];
 
   let arguments = (ins
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
     TF_I32OrI64Tensor:$reduction_indices,
 
     DefaultValuedAttr<BoolAttr, "false">:$keep_dims
   );
 
   let results = (outs
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@@ -7804,33 +7928,6 @@
   TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
 }
 
-def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>,
-                   WithBroadcastableBinOpBuilder {
-  let summary = "Returns x / y element-wise for real types.";
-
-  let description = [{
-If `x` and `y` are reals, this will return the floating-point division.
-
-*NOTE*: `Div` supports broadcasting. More about broadcasting
-[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-  }];
-
-  let arguments = (ins
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
-  );
-
-  let results = (outs
-    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-
-  let hasCanonicalizer = 1;
-
-  let hasFolder = 1;
-}
-
 def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes the reciprocal of x element-wise.";
 
@@ -9240,6 +9337,18 @@
   );
 }
 
+def TF_ShutdownDistributedTPUOp : TF_Op<"ShutdownDistributedTPU", []> {
+  let summary = "Shuts down a running distributed TPU system.";
+
+  let description = [{
+The op returns an error if no system is running.
+  }];
+
+  let arguments = (ins);
+
+  let results = (outs);
+}
+
 def TF_SigmoidOp : TF_Op<"Sigmoid", [NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes sigmoid of `x` element-wise.";
 
@@ -9758,6 +9867,41 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
 }
 
+def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> {
+  let summary = [{
+Multiply matrix "a" by matrix "b".
+  }];
+
+  let description = [{
+The inputs must be two-dimensional matrices and the inner dimension of "a" must
+match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not
+`SparseTensor`s.  This op is optimized for the case where at least one of "a" or
+"b" is sparse, in the sense that they have a large proportion of zero values.
+The breakeven for using this versus a dense matrix multiply on one platform was
+30% zero values in the sparse matrix.
+
+The gradient computation of this operation will only take advantage of sparsity
+in the input gradient when that gradient comes from a Relu.
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F32]>:$a,
+    TensorOf<[BF16, F32]>:$b,
+
+    DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
+    DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
+    DefaultValuedAttr<BoolAttr, "false">:$a_is_sparse,
+    DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse
+  );
+
+  let results = (outs
+    F32Tensor:$product
+  );
+
+  TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_SparseReshapeOp : TF_Op<"SparseReshape", [NoSideEffect]> {
   let summary = [{
 Reshapes a SparseTensor to represent values in a new dense shape.
@@ -10466,6 +10610,36 @@
   TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
 }
 
+def TF_StringToHashBucketFastOp : TF_Op<"StringToHashBucketFast", [NoSideEffect]> {
+  let summary = [{
+Converts each string in the input Tensor to its hash mod by a number of buckets.
+  }];
+
+  let description = [{
+The hash function is deterministic on the content of the string within the
+process and will never change. However, it is not suitable for cryptography.
+This function may be used when CPU time is scarce and inputs are trusted or
+unimportant. There is a risk of adversaries constructing inputs that all hash
+to the same bucket. To prevent this problem, use a strong hash function with
+`tf.string_to_hash_bucket_strong`.
+
+Examples:
+
+>>> tf.strings.to_hash_bucket_fast(["Hello", "TensorFlow", "2.x"], 3).numpy()
+array([0, 2, 2])
+  }];
+
+  let arguments = (ins
+    TF_StrTensor:$input,
+
+    Confined<I64Attr, [IntMinValue<1>]>:$num_buckets
+  );
+
+  let results = (outs
+    I64Tensor:$output
+  );
+}
+
 def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>,
                WithBroadcastableBinOpBuilder {
   let summary = "Returns x - y element-wise.";
@@ -11521,9 +11695,9 @@
   TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>;
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
-  // TODO(parkers): Add folds for multiples = [1,...].
-  // TODO(parkers): Add errors for negative multiples and multiples.size() !=
-  // input.rank()
+  let verifier = [{ return Verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
@@ -12715,6 +12889,43 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF__FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> {
+  let summary = "Internal FusedBatchNorm operation: reserved for internal use.";
+
+  let description = [{
+Do not invoke this operator directly in Python. A fusion optimization is
+expected to create these operators.
+  }];
+
+  let arguments = (ins
+    TensorOf<[F16, F32]>:$x,
+    F32Tensor:$scale,
+    F32Tensor:$offset,
+    F32Tensor:$mean,
+    F32Tensor:$variance,
+    Variadic<TensorOf<[F16, F32]>>:$side_input,
+
+    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+    DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
+    DefaultValuedAttr<StrAttr, "Identity">:$activation_mode,
+    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+    DefaultValuedAttr<BoolAttr, "true">:$is_training
+  );
+
+  let results = (outs
+    TensorOf<[F16, F32]>:$y,
+    F32Tensor:$batch_mean,
+    F32Tensor:$batch_variance,
+    F32Tensor:$reserve_space_1,
+    F32Tensor:$reserve_space_2,
+    F32Tensor:$reserve_space_3
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
+  TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>;
+}
+
 def TF__FusedConv2DOp : TF_Op<"_FusedConv2D", [NoSideEffect]> {
   let summary = [{
 Performs a convolution followed by a specified series of operations.
@@ -12752,7 +12963,8 @@
     DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations,
     DefaultValuedAttr<BoolAttr, "true">:$use_cudnn_on_gpu,
     DefaultValuedAttr<StrArrayAttr, "{}">:$fused_ops,
-    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon
+    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+    DefaultValuedAttr<F32Attr, "0.2f">:$leakyrelu_alpha
   );
 
   let results = (outs
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index 1755c97..4624680 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -157,20 +157,10 @@
          "TensorFlow " # description # " type">,
     BuildableType<"getType<mlir::TF::" # name # "Type>()">;
 
-// Any tensor element type allowed in TensorFlow ops
-def TF_ElementType : Type<Or<[AnyFloat.predicate,
-                              AnySignlessInteger.predicate,
-                              AnyUnsignedInteger.predicate,
-                              AnyComplex.predicate,
-                              TF_TFDialectType.predicate]>,
-                          "tf.dtype">;
-
-// Any TensorFlow tensor type
-def TF_Tensor : TensorOf<[TF_ElementType]>;
-
 //===----------------------------------------------------------------------===//
 // Integer types
 
+// TODO(mgester) shouldn't this be SignedIntOfWidths?
 def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>;
 
 def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
@@ -191,10 +181,11 @@
 def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
 
 // Any signed integer type
+// TODO(mgester) shouldn't this be SignedIntOfWidths?
 def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
 
 // Any integer type
-def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>;
+def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;
 
 // Any integer tensor types
 def TF_IntTensor : TensorOf<[TF_Int]>;
@@ -208,8 +199,8 @@
 def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">;
 
 // Any quantized type
-def TF_AnyQuantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8,
-                              TF_Quint16]>;
+def TF_Quantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8,
+                              TF_Quint16], "quantized">;
 //===----------------------------------------------------------------------===//
 // Floating-point types
 
@@ -217,8 +208,10 @@
 
 def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>;
 
+def TF_Float : AnyTypeOf<[F16, F32, F64, BF16], "floating-point">;
+
 // Any floating-point tensor types
-def TF_FpTensor : TensorOf<[AnyFloat]>;
+def TF_FpTensor : TensorOf<[TF_Float]>;
 
 //===----------------------------------------------------------------------===//
 // Complex types
@@ -231,10 +224,9 @@
 def TF_Complex128 : Complex<F<64>>;
 def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
 
-def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128],
-                              "64/128-bit complex type">;
+def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;
 
-def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>;
+def TF_ComplexTensor : TensorOf<[TF_Complex]>;
 
 //===----------------------------------------------------------------------===//
 // String/variant/resource types
@@ -249,28 +241,113 @@
 def TF_ResourceTensor : TensorOf<[TF_Resource]>;
 
 //===----------------------------------------------------------------------===//
+// Reference types
+
+// Float reference types
+def TF_F16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
+def TF_F32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
+def TF_F64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
+def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;
+
+// Any float reference type
+def TF_FloatRef : AnyTypeOf<[TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_Bfloat16Ref],
+                            "floating-point reference">;
+
+// Complex reference types
+def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
+def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;
+
+// Any complex reference type
+def TF_ComplexRef : AnyTypeOf<[TF_Complex64Ref, TF_Complex128Ref],
+                              "complex reference">;
+
+// Integer reference types
+def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
+def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
+def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
+def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;
+
+def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
+def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
+def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
+def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;
+
+// Any signed integer reference type
+def TF_SIntRef : AnyTypeOf<[TF_Int8Ref, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref],
+                           "signed integer reference">;
+
+// Any unsigned integer reference type
+def TF_UIntRef : AnyTypeOf<[TF_Uint8Ref, TF_Uint16Ref, TF_Uint32Ref,
+                            TF_Uint64Ref], "unsigned integer reference">;
+
+// Any integer reference type
+def TF_IntRef : AnyTypeOf<[TF_SIntRef, TF_UIntRef], "integer reference">;
+
+// Quantized reference types
+def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
+def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
+def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
+def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
+def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;
+
+// Any quantized reference type
+def TF_QuantizedRef : AnyTypeOf<[TF_Qint8Ref, TF_Qint16Ref, TF_Qint32Ref,
+                                 TF_Quint8Ref, TF_Quint16Ref], "quantized reference">;
+
+// Other reference types
+def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
+def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
+def TF_StringRef : TF_TensorFlowType<"StringRef", "stringref">;
+def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;
+
+// Reference tensor types
+def TF_FpRefTensor : TensorOf<[TF_FloatRef]>;
+def TF_I32OrI64RefTensor : TensorOf<[TF_Int32Ref, TF_Int64Ref]>;
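+
+// Illustrative sketch (not part of the generated definitions): an op that
+// should also accept reference variables can list both the value and the
+// reference constraints in a single TensorOf, e.g.
+//   TensorOf<[TF_Float, TF_FloatRef]>:$x
+// The hand-written ops in tf_ops.td (e.g. TF_AddV2Op) opt in this way.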
+
+//===----------------------------------------------------------------------===//
 // Multi-category type constraints
 
 def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>;
 
-def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>;
+def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32Or64]>;
 
 // Any integer or floating-point tensor types
-def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>;
+def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
 
-def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>;
+def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
 
-def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>;
+def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;
 
-def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex],
-                             "number">;
+def TF_Number : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex],
+                          "number">;
+def TF_NumberRef : AnyTypeOf<[TF_IntRef, TF_FloatRef, TF_QuantizedRef,
+                              TF_ComplexRef], "number reference">;
 
-def TF_NumberTensor : TensorOf<[TF_AnyNumber]>;
+def TF_NumberTensor : TensorOf<[TF_Number]>;
+def TF_NumberRefTensor : TensorOf<[TF_NumberRef]>;
 
-def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>;
+def TF_NumberOrStr : AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8,
+                                TF_Str]>;
 def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>;
 
 //===----------------------------------------------------------------------===//
+// Tensor and tensor element types
+
+// Bool type
+def TF_Bool : I<1>;
+
+// Any tensor element type allowed in TensorFlow ops
+// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
+def TF_ElementType : Type<Or<[TF_Float.predicate,
+                              TF_Complex.predicate,
+                              TF_Int.predicate,
+                              TF_Bool.predicate,
+                              TF_TFDialectType.predicate]>,
+                          "tf.dtype">;
+
+// Any TensorFlow tensor type
+def TF_Tensor : TensorOf<[TF_ElementType]>;
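+
+// Illustrative sketch (assumed usage, not generated): a maximally permissive
+// unary op signature would be written against TF_Tensor as
+//   let arguments = (ins TF_Tensor:$input);
+//   let results = (outs TF_Tensor:$output);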
+
+//===----------------------------------------------------------------------===//
 // TensorFlow attribute definitions
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 52f828e..bc76cd3 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -217,30 +217,6 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-
-def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> {
-  let summary = "Permute input tensor from `src_format` to `dst_format`";
-
-  let description = [{
-Input tensor must be a vector of size 4, or a 4x2 tensor.
-  }];
-
-  let arguments = (ins
-    TF_I32OrI64Tensor:$x,
-
-    DefaultValuedAttr<StrAttr, "NHWC">:$src_format,
-    DefaultValuedAttr<StrAttr, "NCHW">:$dst_format
-  );
-
-  let results = (outs
-    TF_I32OrI64Tensor:$y
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-
-  let verifier = [{ return Verify(*this); }];
-}
-
 def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
   let summary = "Creates and returns an empty tensor list.";
 
@@ -392,38 +368,8 @@
   let verifier = [{
     return Verify(*this);
   }];
-}
 
-def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
-  let summary = "Computes the mean of elements across dimensions of a tensor.";
-
-  let description = [{
-Reduces `input` along the dimensions given in `axis`. Unless
-`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-`axis`. If `keep_dims` is true, the reduced dimensions are
-retained with length 1.
-  }];
-
-  let arguments = (ins
-    TF_NumberTensor:$input,
-    TF_I32OrI64Tensor:$reduction_indices,
-
-    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
-  );
-
-  let results = (outs
-    TF_NumberTensor:$output
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
-
-  let extraClassDeclaration = [{
-    // TF_FoldOperandsTransposeInterface:
-    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
-    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
-    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
-  }];
+  let hasCanonicalizer = 1;
 }
 
 def TF_LegacyCallOp : TF_Op<"LegacyCall",
@@ -624,36 +570,6 @@
   DerivedAttr shape = TF_DerivedResultShapeAttr;
 }
 
-def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> {
-  let summary = [{
-SparseMatMul is MatMul with hints on the sparseness of the matrices.
-  }];
-
-  let description = [{
-Similar to MatMul, with a_is_sparse and b_is_sparse indicating whether a and b
-are sparse matrices.
-  }];
-
-  let arguments = (ins
-    TensorOf<[BF16, F32]>:$a,
-    TensorOf<[BF16, F32]>:$b,
-
-    DefaultValuedAttr<BoolAttr, "true">:$a_is_sparse,
-    DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse,
-
-    DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
-    DefaultValuedAttr<BoolAttr, "false">:$transpose_b
-  );
-
-  let results = (outs
-    TensorOf<[F32]>:$product
-  );
-
-  TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;
-  TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>;
-}
-
-
 def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
                                          [CallOpInterface]> {
   let summary =
@@ -884,45 +800,6 @@
     TF_DerivedOperandOrResultHandleShapeAttr<"resource">;
 }
 
-// Not generated because it begins with an underscore, which isn't allowed by
-// the C++ standard.
-def TF_FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> {
-  let summary = "Internal FusedBatchNorm operation: reserved for internal use";
-
-  let description = [{
- Do not invoke this operator directly in Python. A fusion optimization is
- expected to create these operators.
-  }];
-
-  let arguments = (ins
-    TensorOf<[F16, F32]>:$x,
-    F32Tensor:$scale,
-    F32Tensor:$offset,
-    F32Tensor:$mean,
-    F32Tensor:$variance,
-    Variadic<TensorOf<[F16, F32]>>:$side_input,
-
-    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
-    DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
-    DefaultValuedAttr<StrAttr, "Identity">:$activation_mode,
-    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
-    DefaultValuedAttr<BoolAttr, "true">:$is_training
-  );
-
-  let results = (outs
-    TensorOf<[F16, F32]>:$y,
-    F32Tensor:$batch_mean,
-    F32Tensor:$batch_variance,
-    F32Tensor:$reserve_space_1,
-    F32Tensor:$reserve_space_2,
-    F32Tensor:$reserve_space_3
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
-  TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>;
-}
-
 // Multiple variadic operands with different sizes are not supported by the
 // dialect generator, so we manually added the op.
 def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> {
@@ -1272,36 +1149,6 @@
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
-def TF_StringToHashBucketFastOp : TF_Op<"StringToHashBucketFast", [NoSideEffect]> {
-  let summary = [{
-Converts each string in the input Tensor to its hash mod by a number of buckets.
-  }];
-
-  let description = [{
-The hash function is deterministic on the content of the string within the
-process and will never change. However, it is not suitable for cryptography.
-This function may be used when CPU time is scarce and inputs are trusted or
-unimportant. There is a risk of adversaries constructing inputs that all hash
-to the same bucket. To prevent this problem, use a strong hash function with
-`tf.string_to_hash_bucket_strong`.
-
-Examples:
-
->>> tf.strings.to_hash_bucket_fast(["Hello", "TensorFlow", "2.x"], 3).numpy()
-array([0, 2, 2])
-  }];
-
-  let arguments = (ins
-    TF_StrTensor:$input,
-
-    Confined<I64Attr, [IntMinValue<1>]>:$num_buckets
-  );
-
-  let results = (outs
-    I64Tensor:$output
-  );
-}
-
 def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
   let summary = "Calls a function placed on a specified TPU device.";
 
@@ -1336,63 +1183,6 @@
   let verifier = [{ return VerifyPartitionedCall(*this); }];
 }
 
-class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
-  let summary = "Batch normalization.";
-
-  let description = [{
-Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
-The size of 1D Tensors matches the dimension C of the 4D Tensors.
-  }];
-
-  let arguments = (ins
-    TensorOf<[BF16, F16, F32]>:$x,
-    F32Tensor:$scale,
-    F32Tensor:$offset,
-    F32Tensor:$mean,
-    F32Tensor:$variance,
-
-    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
-    DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
-    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
-    DefaultValuedAttr<BoolAttr, "true">:$is_training
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
-
-  let extraClassDeclaration = [{
-    // TF_FoldOperandsTransposeInterface:
-    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
-    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
-    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
-
-    // TF_LayoutSensitiveInterface:
-    StringRef GetOptimalLayout(const RuntimeDevices& devices);
-    LogicalResult UpdateDataFormat(StringRef data_format);
-  }];
-}
-
-def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
-  let results = (outs
-    TensorOf<[BF16, F16, F32]>:$y,
-    F32Tensor:$batch_mean,
-    F32Tensor:$batch_variance,
-    F32Tensor:$reserve_space_1,
-    F32Tensor:$reserve_space_2
-  );
-}
-
-def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
-  let results = (outs
-    TensorOf<[BF16, F16, F32]>:$y,
-    F32Tensor:$batch_mean,
-    F32Tensor:$batch_variance,
-    F32Tensor:$reserve_space_1,
-    F32Tensor:$reserve_space_2,
-    F32Tensor:$reserve_space_3
-  );
-}
-
 def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> {
   let summary = [{
 Batches all the inputs tensors to the computation done by the function.
@@ -1464,4 +1254,98 @@
   TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
 }
 
+def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
+                 WithBroadcastableBinOpBuilder {
+  let summary = "Returns x + y element-wise.";
+
+  let description = [{
+*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+  }];
+
+  let arguments = (ins
+    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$x,
+    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$y
+  );
+
+  let results = (outs
+    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$z
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+  let hasCanonicalizer = 1;
+
+  let hasFolder = 1;
+}
+
+def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
+                    WithBroadcastableBinOpBuilder {
+  let summary = "Returns 0 if the denominator is zero.";
+
+  let description = [{
+*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+  }];
+
+  let arguments = (ins
+    TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$x,
+    TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$y
+  );
+
+  let results = (outs
+    TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$z
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
+                   WithBroadcastableBinOpBuilder {
+  let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise.";
+
+  let description = [{
+*NOTE*: `Maximum` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+  }];
+
+  let arguments = (ins
+    TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$x,
+    TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$y
+  );
+
+  let results = (outs
+    TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$z
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>,
+                   WithBroadcastableBinOpBuilder {
+  let summary = "Returns x / y element-wise for real types.";
+
+  let description = [{
+If `x` and `y` are reals, this will return the floating-point division.
+
+*NOTE*: `Div` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+  }];
+
+  let arguments = (ins
+    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$x,
+    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$y
+  );
+
+  let results = (outs
+    TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$z
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+  let hasCanonicalizer = 1;
+
+  let hasFolder = 1;
+}
+
 #endif // TF_OPS
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 4104428..b465c1d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -1935,6 +1935,7 @@
 // IfOp canonicalization.
 //===----------------------------------------------------------------------===//
 
+namespace {
 class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
  public:
   explicit FoldConstantIfOp(MLIRContext *context)
@@ -1966,7 +1967,7 @@
   auto rewrite = [&](auto op_type) {
     auto empty = rewriter.getStringAttr("");
     auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
-        op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
+        op.getLoc(), op.getResultTypes(), op.input(), func,
         /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
     CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
     rewriter.replaceOp(op, call_op.getResults());
@@ -1979,6 +1980,7 @@
 
   return success();
 }
+}  // anonymous namespace
 
 void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
@@ -1997,6 +1999,61 @@
   return success();
 }
 
+namespace {
+class FoldConstantIfRegionOp : public OpRewritePattern<TF::IfRegionOp> {
+ public:
+  explicit FoldConstantIfRegionOp(MLIRContext *context)
+      : OpRewritePattern<TF::IfRegionOp>(context) {}
+  LogicalResult matchAndRewrite(TF::IfRegionOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+LogicalResult FoldConstantIfRegionOp::matchAndRewrite(
+    TF::IfRegionOp op, PatternRewriter &rewriter) const {
+  // Extract the constant cond value.
+  DenseIntElementsAttr cond_attr;
+  if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
+
+  // IfRegion condition should always be a scalar. Select the region to fold to.
+  bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
+  Region &region = cond ? op.then_branch() : op.else_branch();
+
+  // If the IfRegion is stateless but the region being inlined itself is not
+  // stateless, then inlining the region could cause a loss of information.
+  // However, it's probably better to fold the IfRegion instead of having the
+  // dead branch stay.
+
+  // Inline the region in place of the IfRegion op, and forward the yield
+  // inputs to the IfRegion op results.
+  auto yield = cast<YieldOp>(region.front().getTerminator());
+  auto updated_results = llvm::to_vector<4>(yield.getOperands());
+
+  // If the yield types do not match the IfRegion result types, add appropriate
+  // casts.
+  rewriter.setInsertionPoint(yield);
+  for (auto it : llvm::zip(op.getResultTypes(), updated_results)) {
+    auto &updated_result = std::get<1>(it);
+    Type result_type = std::get<0>(it);
+    if (result_type != updated_result.getType()) {
+      updated_result =
+          rewriter.create<TF::CastOp>(op.getLoc(), result_type, updated_result,
+                                      /*Truncate=*/rewriter.getBoolAttr(false));
+    }
+  }
+  // Inline the region into the block containing the IfRegion.
+  rewriter.mergeBlockBefore(&region.front(), op);
+  rewriter.eraseOp(yield);
+  rewriter.replaceOp(op, updated_results);
+  return success();
+}
+}  // anonymous namespace
+
+void IfRegionOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  results.insert<FoldConstantIfRegionOp>(context);
+}
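+
+// Illustrative example of the fold above (a sketch, not a lit test): with a
+// constant true condition,
+//   %0 = "tf.IfRegion"(%true) ({
+//     "tf.Yield"(%a) : (tensor<f32>) -> ()
+//   }, {
+//     "tf.Yield"(%b) : (tensor<f32>) -> ()
+//   }) {is_stateless = true} : (tensor<i1>) -> tensor<f32>
+// canonicalizes to %a, with a tf.Cast inserted first if the yield type
+// differs from the result type.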
+
 //===----------------------------------------------------------------------===//
 // InvertOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 925a2af..52a2cee 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -1144,6 +1144,13 @@
 }
 
 //===----------------------------------------------------------------------===//
+// SpaceToBatchNDOp
+//===----------------------------------------------------------------------===//
+
+// TODO(b/157475606): Add Verify(SpaceToBatchNDOp)
+// TODO(b/157475606): Add SpaceToBatchNDOp::inferReturnTypes
+
+//===----------------------------------------------------------------------===//
 // SparseSoftmaxCrossEntropyWithLogitsOp
 //===----------------------------------------------------------------------===//
 
@@ -1777,6 +1784,87 @@
 }
 
 //===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
+// Verifies that:
+//
+// - input has at least rank 1
+// - multiples is rank 1
+// - multiples.size() == input.rank()
+// - input.rank() == output.rank()
+// - Elements in multiples are non-negative
+// - input.shape[i] * multiples[i] == output.shape[i]
+//   for i in [0, input.rank() - 1]
+
+static LogicalResult Verify(TileOp op) {
+  auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
+  auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
+  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
+
+  if (multiples_type && multiples_type.getRank() != 1) {
+    return op.emitOpError() << "expected multiples to be rank 1, got rank = "
+                            << multiples_type.getRank();
+  }
+
+  if (input_type && multiples_type && multiples_type.hasStaticShape() &&
+      (input_type.getRank() != multiples_type.getNumElements() ||
+       (input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
+    return op.emitOpError()
+           << "expected size of multiples equal to rank of input"
+           << ", got multiples of size " << multiples_type.getNumElements()
+           << ", and input of rank " << input_type.getRank();
+  }
+
+  if (input_type && output_type) {
+    if (input_type.getRank() != output_type.getRank()) {
+      return op.emitOpError()
+             << "expected rank of input to equal to rank of output"
+             << ", got input of rank " << input_type.getRank()
+             << ", and output of rank " << output_type.getRank();
+    }
+
+    DenseIntElementsAttr multiples_attr;
+    if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
+      for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
+        const int64_t input_dim = input_type.getDimSize(i);
+        const int64_t output_dim = output_type.getDimSize(i);
+        const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();
+
+        if (m < 0) {
+          return op.emitOpError()
+                 << "expected multiples to be non-negative, got "
+                 << "multiples[" << i << "] = " << m;
+        }
+
+        if (!ShapedType::isDynamic(input_dim) &&
+            !ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
+          return op.emitOpError()
+                 << "requires input.shape[" << i << "] (" << input_dim << ")"
+                 << " * " << m << " to be equal to "
+                 << "output.shape[" << i << "] (" << output_dim << ")";
+        }
+      }
+    }
+  }
+
+  return success();
+}
+
+OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
+  DenseIntElementsAttr multiples_attr;
+  if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
+    // Return the input directly when multiples are all ones,
+    // regardless of what the input is.
+    if (multiples_attr.isSplat() &&
+        multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
+      return input();
+    }
+  }
+  return {};
+}
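+
+// Illustrative example of the fold above (sketch): tiling by all-ones
+// multiples is the identity, so
+//   %m = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>
+//   %0 = "tf.Tile"(%x, %m) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+// folds to %x directly.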
+
+//===----------------------------------------------------------------------===//
 // TopKV2Op
 //===----------------------------------------------------------------------===//
 
@@ -1796,26 +1884,57 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-// If the input to ToBoolOp is a `tensor<i1>`, then the ToBoolOp is an identity
-// function and can be removed.
-class ToBoolOfZeroDBoolTensor : public OpRewritePattern<ToBoolOp> {
+// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be
+// rewritten into an identity, a comparison against zero/empty, or a constant.
+class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
   using OpRewritePattern<ToBoolOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ToBoolOp op,
                                 PatternRewriter &rewriter) const override {
-    if (auto type = op.getOperand().getType().dyn_cast<RankedTensorType>()) {
-      if (type.getRank() == 0 && type.getElementType().isInteger(1)) {
-        rewriter.replaceOp(op, op.getOperand());
-        return success();
-      }
+    auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
+    // If the input is an unranked tensor, we cannot rewrite it.
+    if (!type) return failure();
+
+    // Expected return type of the ToBool operation.
+    auto result_type = op.getResult().getType().cast<RankedTensorType>();
+
+    // If input is already a tensor<i1>, it can be folded into an identity.
+    if (type == result_type) {
+      rewriter.replaceOp(op, op.getOperand());
+      return success();
     }
-    return failure();
+
+    if (type.getRank() == 0) {
+      // If the input is a scalar tensor, the ToBool can be expanded to
+      // element != 0 (for numerical values) or element != "" (for strings).
+      Type element_type = type.getElementType();
+      Attribute zero_attr;
+      if (element_type.isIntOrFloat())
+        zero_attr = rewriter.getZeroAttr(type);
+      else if (element_type.isa<TF::StringType>())
+        zero_attr = DenseStringElementsAttr::get(type, {""});
+
+      if (!zero_attr) return failure();
+
+      auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
+      rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
+          op, result_type, op.getOperand(), zero_const, false);
+    } else {
+      // If the input is a non-scalar ranked tensor, ToBool can be expanded
+      // to numElements != 0. numElements will be 0 iff one of the dimensions is
+      // zero.
+      bool any_zero =
+          llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
+      rewriter.replaceOpWithNewOp<TF::ConstOp>(
+          op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
+    }
+    return success();
   }
 };
 }  // namespace
 
 void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<ToBoolOfZeroDBoolTensor>(context);
+  results.insert<ToBoolOfRankedTensor>(context);
 }
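+
+// Illustrative examples of the rewrite above (sketch): a scalar
+//   %0 = "tf.ToBool"(%x) : (tensor<i32>) -> tensor<i1>
+// becomes "tf.NotEqual"(%x, <zero-const>), while a ranked non-scalar input
+// folds to a constant: true unless some dimension is statically zero.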
 
 //===----------------------------------------------------------------------===//
@@ -1908,11 +2027,9 @@
 namespace {
 
 OpFoldResult FoldIdentityTranspose(TransposeOp op) {
-  auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
-  if (!const_perm) return {};
-
-  auto const_value = const_perm.value();
-  const auto elements = const_value.getValues<APInt>();
+  DenseIntElementsAttr perm;
+  if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
+  const auto elements = perm.getValues<APInt>();
 
   for (auto it : llvm::enumerate(elements)) {
     if (it.index() != it.value()) return {};
@@ -1935,14 +2052,14 @@
   if (!transpose) return {};
 
   // Permutations defined by constant operations.
-  auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
-  auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp());
-  if (!perm0 || !perm1) return {};
+  DenseIntElementsAttr perm0;
+  DenseIntElementsAttr perm1;
+  if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
+      !matchPattern(transpose.perm(), m_Constant(&perm1)))
+    return {};
 
   // With permutation indices that cancel each other
-  auto perm0_value = perm0.value().cast<DenseIntElementsAttr>();
-  auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
-  if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
+  if (!AreCancellablePermutations(perm0, perm1)) return {};
 
   return transpose.x();
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
index 6883d03..2eaa511 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
@@ -115,6 +115,11 @@
 TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
     : Dialect(/*name=*/"tf_saved_model", context,
               TypeID::get<TensorFlowSavedModelDialect>()) {
+  // The TensorFlow Dialect is needed in the verifier and other routines
+  // associated to this dialect. It makes little sense anyway to use the
+  // SavedModel dialect without the TensorFlow Dialect.
+  context->loadDialect<TF::TensorFlowDialect>();
+
   addOperations<
 #define GET_OP_LIST
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
index 05d34eb..6654341 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/breakup-islands.mlir
@@ -285,7 +285,7 @@
 // and certain tf_executor ops are added correctly.
 
 // CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
-// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]]
+// CHECK: tf_executor.NextIteration.Sink[{{.*}}] {{.*}}, %[[CONTROL]]
 func @next_iteration_sink_control_input() {
   tf_executor.graph {
     %source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index 8c3e8dc..ff90c6f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -568,6 +568,14 @@
   return %0: tensor<*xf16>
 }
 
+// CHECK-LABEL: testTileMultiplesAllOnes
+func @testTileMultiplesAllOnes(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %cst = constant dense<[1, 1]> : tensor<2xi32>
+  // CHECK: return %arg0
+  %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+  return %0: tensor<2x3xf32>
+}
+
 // CHECK-LABEL: testLogicalNotOfEqual
 func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
   %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
@@ -702,6 +710,15 @@
   // CHECK: return %arg0
 }
 
+// CHECK-LABEL: @identityTransposeConst
+func @identityTransposeConst(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> {
+  %0 = constant dense<[0, 1, 2, 3, 4]> : tensor<5xi32>
+  %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x5x6xf32>
+
+  return %1 : tensor<2x3x4x5x6xf32>
+  // CHECK: return %arg0
+}
+
 // CHECK-LABEL: @nonIdentityTranspose
 func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32> {
   %0 = "tf.Const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
@@ -724,6 +741,17 @@
   // CHECK: return %arg0
 }
 
+// CHECK-LABEL: @cancellableTransposeConst
+func @cancellableTransposeConst(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
+  %0 = constant dense<[0, 3, 1, 2]> : tensor<4xi32>
+  %1 = constant dense<[0, 2, 3, 1]> : tensor<4xi32>
+  %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
+  %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
+
+  return %3 : tensor<1x4x4x8xf32>
+  // CHECK: return %arg0
+}
+
 // CHECK-LABEL: @nonCancellableTranspose
 func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
   %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
@@ -742,13 +770,72 @@
   return %0 : tensor<*xf32>
 }
 
-// CHECK-LABEL: func @ToBool_0DScalar
-func @ToBool_0DScalar(%arg0: tensor<i1>) -> tensor<i1> {
+// CHECK-LABEL: func @ToBool_0DScalarI1
+func @ToBool_0DScalarI1(%arg0: tensor<i1>) -> tensor<i1> {
   // CHECK: return %arg0
   %0 = "tf.ToBool"(%arg0) : (tensor<i1>) -> tensor<i1>
   return %0 : tensor<i1>
 }
 
+// CHECK-LABEL: func @ToBool_0DScalarInt
+func @ToBool_0DScalarInt(%arg0: tensor<i32>) -> tensor<i1> {
+  // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
+  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]])
+  // CHECK: return [[NE]]
+  %0 = "tf.ToBool"(%arg0) : (tensor<i32>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: func @ToBool_0DScalarFloat
+func @ToBool_0DScalarFloat(%arg0: tensor<f32>) -> tensor<i1> {
+  // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]])
+  // CHECK: return [[NE]]
+  %0 = "tf.ToBool"(%arg0) : (tensor<f32>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: func @ToBool_0DScalarString
+func @ToBool_0DScalarString(%arg0: tensor<!tf.string>) -> tensor<i1> {
+  // CHECK: [[EmptyStr:%.*]] = "tf.Const"() {value = dense<""> : tensor<!tf.string>} : () -> tensor<!tf.string>
+  // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[EmptyStr]]) {incompatible_shape_error = false} : (tensor<!tf.string>, tensor<!tf.string>) -> tensor<i1>
+  // CHECK: return [[NE]] : tensor<i1>
+  %0 = "tf.ToBool"(%arg0) : (tensor<!tf.string>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: func @ToBool_1DTensor
+func @ToBool_1DTensor(%arg0: tensor<1xf32>) -> tensor<i1> {
+  // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: return [[Const]]
+  %0 = "tf.ToBool"(%arg0) : (tensor<1xf32>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: func @ToBool_1DTensorZeroDim
+func @ToBool_1DTensorZeroDim(%arg0: tensor<0xf32>) -> tensor<i1> {
+  // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: return [[Const]]
+  %0 = "tf.ToBool"(%arg0) : (tensor<0xf32>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: func @ToBool_2DTensor
+func @ToBool_2DTensor(%arg0: tensor<1x5xf32>) -> tensor<i1> {
+  // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: return [[Const]]
+  %0 = "tf.ToBool"(%arg0) : (tensor<1x5xf32>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: func @ToBool_2DTensorZeroDim
+func @ToBool_2DTensorZeroDim(%arg0: tensor<1x0xf32>) -> tensor<i1> {
+  // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  // CHECK: return [[Const]]
+  %0 = "tf.ToBool"(%arg0) : (tensor<1x0xf32>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
 // CHECK-LABEL: testReadVariableOpOfCast
 func @testReadVariableOpOfCast(%arg0: tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<8x40xf32> {
   %0 = "tf.Cast"(%arg0) : (tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<*x!tf.resource>
@@ -843,6 +930,51 @@
   return %4 : tensor<f32>
 }
 
+// CHECK-LABEL: foldIfRegion
+func @foldIfRegion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>, tensor<f32>) {
+  %false = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  %true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+
+  // CHECK: [[Val0:%.*]] = "tf.Mul"(%arg0, %arg1)
+  %0 = "tf.IfRegion"(%true) ({
+      %true_value = "tf.Mul"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "tf.Yield"(%true_value) : (tensor<f32>) -> ()
+    }, {
+      %false_value = "tf.Sub"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "tf.Yield"(%false_value) : (tensor<f32>) -> ()
+    }) { is_stateless = true}: (tensor<i1>) -> tensor<f32>
+
+  // CHECK: [[Val1:%.*]] = "tf.Sub"(%arg0, %arg1)
+  %1 = "tf.IfRegion"(%false) ({
+      %true_value = "tf.Mul"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "tf.Yield"(%true_value) : (tensor<f32>) -> ()
+    }, {
+      %false_value = "tf.Sub"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "tf.Yield"(%false_value) : (tensor<f32>) -> ()
+    }) { is_stateless = true}: (tensor<i1>) -> tensor<f32>
+
+  // CHECK: return [[Val0]], [[Val1]]
+  return %0, %1 : tensor<f32>, tensor<f32>
+}
+
+// CHECK-LABEL: foldIfRegionMismatchedTypes
+func @foldIfRegionMismatchedTypes(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<i1>) -> tensor<1xf32> {
+  %false = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  %true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+
+  // CHECK: [[Val0:%.*]] = "tf.Mul"(%arg0, %arg1)
+  // CHECK-NEXT: [[Cast:%.*]] = "tf.Cast"([[Val0]])
+  // CHECK-NEXT: return [[Cast]]
+  %0 = "tf.IfRegion"(%true) ({
+      %true_value = "tf.Mul"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+      "tf.Yield"(%true_value) : (tensor<?xf32>) -> ()
+    }, {
+      %false_value = "tf.Sub"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+      "tf.Yield"(%false_value) : (tensor<?xf32>) -> ()
+    }) { is_stateless = true}: (tensor<i1>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
 // CHECK-LABEL: foldCase
 func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
   %2 = constant dense<1> : tensor<i32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index f114d17..fff985e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -492,3 +492,13 @@
   return %3 : tensor<8x10000xi32>
 }
 // LINT.ThenChange(../transforms/constant_fold.cc:folding-policy)
+
+func @fold_conv() -> tensor<1x520x520x1xf32> {
+  %0 = "tf.Const"() {value = dense<0.111111112> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32>
+  %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<1x520x520x1xf32>} : () -> tensor<1x520x520x1xf32>
+  %2 = "tf.DepthwiseConv2dNative"(%1, %0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x520x520x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x520x520x1xf32>
+  return %2 : tensor<1x520x520x1xf32>
+
+  // CHECK: tf.Const
+  // CHECK-NOT: tf.DepthwiseConv2dNative
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir b/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir
new file mode 100644
index 0000000..8250bcf
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir
@@ -0,0 +1,16 @@
+// RUN: tf-opt -tf-tensor-device-copy %s | FileCheck %s --dump-input=fail
+
+// CHECK-LABEL: func @fold_identity
+// CHECK-SAME: ([[arg0:%.*]]: tensor<2x2xf32>, [[arg1:%.*]]: tensor<2x2xf32>
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32}} {
+  func @fold_identity(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+    %0 = tf_executor.graph {
+      // CHECK: tf.MatMul
+      %outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+      // CHECK-NOT: tf.Identity
+      %outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
+      tf_executor.fetch %outputs_0 : tensor<2x2xf32>
+    }
+    return %0 : tensor<2x2xf32>
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir
index bec4818..726495f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir
@@ -220,7 +220,7 @@
     %11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor<i32>
     %12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
     %13 = tf_executor.ControlTrigger %2, %12#1, %9#1
-    tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32>
+    tf_executor.NextIteration.Sink[%3#1] %12#0, %13 : tensor<*xi32>
     tf_executor.fetch
   }
   return
@@ -244,7 +244,7 @@
 // CHECK-NEXT:     %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]])
 // CHECK-NEXT:     tf_executor.yield %[[OP_G]] : tensor<*xi32>
 // CHECK:        %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]]
-// CHECK-NEXT:   tf_executor.NextIteration.Sink [%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]]
+// CHECK-NEXT:   tf_executor.NextIteration.Sink[%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]]
 
 
 // Test no merging took place as cycle would be formed otherwise.
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
index e21fd90..a6b1979 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
@@ -7,7 +7,7 @@
 # CHECK:    %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source
 # CHECK:    tf_executor.Merge {{.*}} %[[NEXTITERATION]]
 
-# CHECK:    tf_executor.NextIteration.Sink [%[[NEXTITERATION_token]]]
+# CHECK:    tf_executor.NextIteration.Sink[%[[NEXTITERATION_token]]]
 
 node {
   name: "Const"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
index bd8f740..01e9940 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
@@ -215,6 +215,42 @@
   return %0 : tensor<*xf32>
 }
 
+// %input has 1 batch dimension, then 2 block dimensions, then 1 remainder dimension.
+// CHECK-LABEL: SpaceToBatchND
+func @SpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor<*xf32> {
+  // CHECK-DAG: [[I0:%.+]] = "tf.Const"() {value = dense<0> : tensor<1xi64>}
+  // CHECK-DAG: [[I1:%.+]] = "tf.Const"() {value = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[I2:%.+]] = "tf.Const"() {value = dense<2> : tensor<1xi64>}
+  // CHECK-DAG: [[I3:%.+]] = "tf.Const"() {value = dense<3> : tensor<1xi64>}
+  // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>}
+  // CHECK-DAG: [[AXIS:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+  // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[AXIS]])
+  // CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
+  // CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Einsum"([[FULL_PADDINGS]]) {equation = "ij->i"}
+  // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
+  // CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
+  // CHECK-DAG: [[PADDED_SHAPE_0:%.+]] = "tf.Slice"([[PADDED_SHAPE]], [[I0]], [[I1]])
+  // CHECK-DAG: [[PADDED_SHAPE_1:%.+]] = "tf.Slice"([[PADDED_SHAPE]], [[I1]], [[I1]])
+  // CHECK-DAG: [[PADDED_SHAPE_2:%.+]] = "tf.Slice"([[PADDED_SHAPE]], [[I2]], [[I1]])
+  // CHECK-DAG: [[PADDED_SHAPE_3:%.+]] = "tf.Slice"([[PADDED_SHAPE]], [[I3]], [[I1]])
+  // CHECK-DAG: [[BLOCK_SHAPE_0:%.+]] = "tf.Slice"(%arg1, [[I0]], [[I1]])
+  // CHECK-DAG: [[BLOCK_SHAPE_1:%.+]] = "tf.Slice"(%arg1, [[I1]], [[I1]])
+  // CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_1]], [[BLOCK_SHAPE_0]])
+  // CHECK-DAG: [[OUTER_SHAPE_1:%.+]] = "tf.Div"([[PADDED_SHAPE_2]], [[BLOCK_SHAPE_1]])
+  // CHECK-DAG: [[RESHAPED_SHAPE:%.+]] = "tf.ConcatV2"([[PADDED_SHAPE_0]], [[OUTER_SHAPE_0]], [[BLOCK_SHAPE_0]], [[OUTER_SHAPE_1]], [[BLOCK_SHAPE_1]], [[PADDED_SHAPE_3]], [[AXIS]])
+  // CHECK-DAG: [[PERMUTATION:%.+]] = "tf.Const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi64>}
+  // CHECK-DAG: [[OUTPUT_BATCH_PART:%.+]] = "tf.Mul"([[PADDED_SHAPE_0]], [[BLOCK_SHAPE_0]])
+  // CHECK-DAG: [[OUTPUT_BATCH:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART]], [[BLOCK_SHAPE_1]])
+  // CHECK-DAG: [[OUTPUT_SHAPE:%.+]] = "tf.ConcatV2"([[OUTPUT_BATCH]], [[OUTER_SHAPE_0]], [[OUTER_SHAPE_1]], [[PADDED_SHAPE_3]], [[AXIS]])
+  // CHECK-DAG: [[RESHAPED:%.+]] = "tf.Reshape"([[PADDED]], [[RESHAPED_SHAPE]])
+  // CHECK-DAG: [[PERMUTED:%.+]] = "tf.Transpose"([[RESHAPED]], [[PERMUTATION]])
+  // CHECK-DAG: [[RESULT:%.+]] = "tf.Reshape"([[PERMUTED]], [[OUTPUT_SHAPE]])
+  // CHECK-DAG: return [[RESULT]]
+  %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL: SoftmaxCrossEntropyWithLogits
 // CHECK-SAME: %[[FEATURES:.*]]: tensor<2x3xf32>, %[[LABELS:.*]]: tensor<2x3xf32>
 func @SoftmaxCrossEntropyWithLogits(%features: tensor<2x3xf32>, %labels: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) {
@@ -479,14 +515,38 @@
   return %0 : tensor<1x2xf32>
 }
 
-// CHECK-LABEL: @Reciprocal
-func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+// CHECK-LABEL: @Reciprocal_i32
+func @Reciprocal_i32(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi32>
+  %0 = "tf.Reciprocal"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
+// CHECK-LABEL: @Reciprocal_f32
+func @Reciprocal_f32(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
   // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
   %0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
 
+// CHECK-LABEL: @Reciprocal_complexf32
+func @Reciprocal_complexf32(%arg0: tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>> {
+  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
+  // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<complex<f32>>, tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
+  %0 = "tf.Reciprocal"(%arg0) : (tensor<*xcomplex<f32>>) -> tensor<*xcomplex<f32>>
+  return %0 : tensor<*xcomplex<f32>>
+}
+
+// CHECK-LABEL: @Reciprocal_complexf64
+func @Reciprocal_complexf64(%arg0: tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f64>> {
+  // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>} : () -> tensor<complex<f64>>
+  // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<complex<f64>>, tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f64>>
+  %0 = "tf.Reciprocal"(%arg0) : (tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f64>>
+  return %0 : tensor<*xcomplex<f64>>
+}
+
 // CHECK-LABEL: @ScatterNd
 func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
   // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
index df2add2..dc99d9d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir
@@ -137,7 +137,7 @@
     // CHECK: "tf.IfRegion"
     // CHECK: "tf.StringToNumber"
     // CHECK-NOT: _xla_outside_compilation
-    // CHECK: _xla_outside_compilation = "auto", is_stateless = true
+    // CHECK: _xla_outside_compilation = "auto1", is_stateless = true
     %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %2 = "tf.IfRegion"(%arg0) ( {
       %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
@@ -166,7 +166,7 @@
       %3 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
       "tf.Yield"(%3) : (tensor<f32>) -> ()
      },  {
-      // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>}
+      // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor<!tf.string>}
       // CHECK-NEXT: "tf.StringToNumber"
       // CHECK-SAME: _xla_outside_compilation
       %4 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
@@ -198,7 +198,7 @@
        // CHECK-NOT: _xla_outside_compilation
        %4 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
        %5 = "tf.IfRegion"(%4)({
-         // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>}
+         // CHECK: "tf.Const"() {_xla_outside_compilation = "auto0", value = dense<"1.0"> : tensor<!tf.string>}
          // CHECK-NEXT: "tf.StringToNumber"
          // CHECK-SAME: _xla_outside_compilation
          %6 = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
@@ -229,7 +229,7 @@
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.WhileRegion"
     // CHECK: "tf.StringToNumber"
-    // CHECK: _xla_outside_compilation = "auto", is_stateless = true
+    // CHECK: _xla_outside_compilation = "auto1", is_stateless = true
     %1 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
     %2:2 = "tf.WhileRegion"(%1, %arg0) ( {
       ^bb0(%carg0: tensor<f32>, %carg1: tensor<i32>):
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 3e61357..26df602 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -530,6 +530,21 @@
     return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32>
   }
 
+  // CHECK-LABEL: infer_device_cluster
+  func @infer_device_cluster(%arg0: tensor<1x8x2xi32>) -> (tensor<*xf32>, tensor<*xf32>) {
+    %0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+    %1 = "tf_device.cluster"() ({
+      %2 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x8x2xi32>) -> tensor<1x8x2xf32>
+      tf_device.return %2 : tensor<1x8x2xf32>
+    // CHECK: () -> tensor<1x8x2xf32>
+    }) : () -> tensor<*xf32>
+    // CHECK: "tf.Cast"(%{{.*}}) {Truncate = false} : (tensor<1x8x2xf32>) -> tensor<*xf32>
+    // CHECK: (tensor<i32>, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>)
+    %3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor<i32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+    %4 = addf %1, %1 : tensor<*xf32>
+    return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32>
+  }
+
   // CHECK-LABEL: func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<1xi32>
   func @tensor_cast(%arg0: tensor<1xi32>) -> tensor<*xi32> {
    // CHECK: %[[RESULT:.*]] = tensor_cast
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 9a8d97e..30a763b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -3468,3 +3468,85 @@
   %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
   return %0 : tensor<8x16xf32>
 }
+
+// -----
+
+func @testTile(%arg0: tensor<2x3x?xf32>) {
+  %cst = constant dense<[2, 3, 4]> : tensor<3xi32>
+  %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3x?xf32>, tensor<3xi32>) -> tensor<4x9x?xf32>
+  return
+}
+
+// -----
+
+func @testTileMultipleNotRank1(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1xi32>) {
+  // expected-error @+1 {{expected multiples to be rank 1, got rank = 2}}
+  %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<1x1xi32>) -> tensor<2x3xf32>
+  return
+}
+
+// -----
+
+func @testTileInputRankNotEqualToMultiplesSize(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) {
+  // expected-error @+1 {{expected size of multiples equal to rank of input, got multiples of size 3, and input of rank 2}}
+  %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x3xf32>
+  return
+}
+
+// -----
+
+func @testTileInputRankNotEqualToOutputRank(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) {
+  // expected-error @+1 {{expected rank of input to be equal to rank of output, got input of rank 2, and output of rank 3}}
+  %0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3x1xf32>
+  return
+}
+
+// -----
+
+func @testTileNegativeMultiples(%arg0: tensor<2x3xf32>) {
+  %cst = constant dense<[-1, 1]> : tensor<2xi32>
+  // expected-error @+1 {{expected multiples to be non-negative, got multiples[0] = -1}}
+  %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+  return
+}
+
+// -----
+
+func @testTileInvalidOutputShape(%arg0: tensor<2x3xf32>) {
+  %cst = constant dense<[2, 3]> : tensor<2xi32>
+  // expected-error @+1 {{requires input.shape[1] (3) * 3 to be equal to output.shape[1] (6)}}
+  %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<4x6xf32>
+  return
+}
+
+// -----
+
+// Test reference variable support for some ops (no errors expected)
+
+// CHECK-LABEL: @testMaximumWithRef
+func @testMaximumWithRef(%arg0: tensor<!tf.f32ref>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: tf.Maximum
+  %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<!tf.f32ref>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: @testAddV2WithRef
+func @testAddV2WithRef(%arg0: tensor<!tf.int16ref>, %arg1: tensor<i16>) -> tensor<i16> {
+  // CHECK: tf.AddV2
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<!tf.int16ref>, tensor<i16>) -> tensor<i16>
+  return %0 : tensor<i16>
+}
+
+// CHECK-LABEL: @testRealDivWithRef
+func @testRealDivWithRef(%arg0: tensor<f64>, %arg1: tensor<!tf.f64ref>) -> tensor<f64> {
+  // CHECK: tf.RealDiv
+  %0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<f64>, tensor<!tf.f64ref>) -> tensor<f64>
+  return %0 : tensor<f64>
+}
+
+// CHECK-LABEL: @testDivNoNanWithRef
+func @testDivNoNanWithRef(%arg0: tensor<f32>, %arg1: tensor<!tf.f32ref>) -> tensor<f32> {
+  // CHECK: tf.DivNoNan
+  %0 = "tf.DivNoNan"(%arg0, %arg1) : (tensor<f32>, tensor<!tf.f32ref>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
index 1e53788..23a8e90 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir
@@ -433,7 +433,7 @@
     %1:3 = tf_executor.NextIteration.Source : tensor<*xf32>
     tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32>
 // CHECK: tf_executor.NextIteration.Source : tensor<*xf32>
-// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32>
+// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32>
     tf_executor.fetch %1#0 : tensor<*xf32>
   }
   return %0 : tensor<*xf32>
@@ -445,7 +445,7 @@
     %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"}
     tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
 // CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"}
-// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
+// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
     tf_executor.fetch %1#0 : tensor<*xf32>
   }
   return %0 : tensor<*xf32>
@@ -457,9 +457,9 @@
     %1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
     %2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32>
     %3:3 = tf_executor.NextIteration.Source : tensor<*xf32>
-    tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32>
+    tf_executor.NextIteration.Sink[%3#1] %3#0, %1#2 : tensor<*xf32>
 // CHECK: tf_executor.NextIteration.Source : tensor<*xf32>
-// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32>
+// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32>
     tf_executor.fetch %3#0 : tensor<*xf32>
   }
   return %0 : tensor<*xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
index 9467f89..7b670cd 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir
@@ -11,9 +11,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
   // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
   // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext"
@@ -31,7 +31,7 @@
   // CHECK-NEXT: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1)
   %execute = "tf_device.launch"() ( {
     %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
-      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
     tf_device.return %3 : tensor<i32>
   }) {device = "/device:TPU:0"} : () -> tensor<i32>
   return %execute : tensor<i32>
@@ -49,9 +49,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   // CHECK-NOT: "tf.TPUGetLayoutOp"
   // CHECK-NOT: "tf.TPUCopyWithLayout"
   %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
@@ -62,13 +62,13 @@
   }) {device = "/device:CPU:0"} : () -> ()
   %execute0 = "tf_device.launch"() ( {
     %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
-      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
     tf_device.return %3 : tensor<i32>
   }) {device = "/device:TPU:0"} : () -> tensor<i32>
   %4:2 = "tf._UnKnownOp_"() : () -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
   %execute1 = "tf_device.launch"() ( {
     %5 = "tf.TPUExecute"(%4#0, %4#1, %compile#1)
-      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
     tf_device.return %5 : tensor<i32>
   }) {device = "/device:TPU:0"} : () -> tensor<i32>
   return %execute1 : tensor<i32>
@@ -85,9 +85,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   // CHECK-NOT: "tf.TPUGetLayoutOp"
   // CHECK-NOT: "tf.TPUCopyWithLayout"
   %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:TPU:0"}
@@ -98,7 +98,7 @@
   }) {device = "/device:CPU:0"} : () -> ()
   %execute = "tf_device.launch"() ( {
     %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
-      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
     tf_device.return %3 : tensor<i32>
   }) {device = "/device:TPU:0"} : () -> tensor<i32>
   return %execute : tensor<i32>
@@ -116,9 +116,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   // CHECK-NOT: "tf.TPUGetLayoutOp"
   // CHECK-NOT: "tf.TPUCopyWithLayout"
   %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
@@ -129,7 +129,7 @@
   }) {device = "/device:CPU:0"} : () -> ()
   %execute = "tf_device.launch"() ( {
     %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
-      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
     tf_device.return %3 : tensor<i32>
   }) {device = "/device:TPU:0"} : () -> tensor<i32>
   return %execute : tensor<i32>
@@ -148,9 +148,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   %id1 = "tf.Identity"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<*x!tf.resource>)
   %id2 = "tf.Identity"(%id1) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<*x!tf.resource>)
   // CHECK-NOT: "tf.TPUGetLayoutOp"
@@ -163,7 +163,7 @@
   }) {device = "/device:CPU:0"} : () -> ()
   %execute = "tf_device.launch"() ( {
     %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
-      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
     tf_device.return %3 : tensor<i32>
   }) {device = "/device:TPU:0"} : () -> tensor<i32>
   return %execute : tensor<i32>
@@ -181,9 +181,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   %var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<*x!tf.resource>
   // CHECK-NOT: "tf.TPUGetLayoutOp"
   // CHECK-NOT: "tf.TPUCopyWithLayout"
@@ -195,7 +195,7 @@
   }) {device = "/device:CPU:0"} : () -> ()
   %execute = "tf_device.launch"() ( {
     %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
-      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
     tf_device.return %3 : tensor<i32>
   }) {device = "/device:TPU:0"} : () -> tensor<i32>
   return %execute : tensor<i32>
@@ -212,9 +212,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   // CHECK-NOT: "tf.TPUGetLayoutOp"
   // CHECK-NOT: "tf.TPUCopyWithLayout"
   %2 = "tf._Unknown_"() : () -> tensor<3x3x1x32xf32>
@@ -224,7 +224,7 @@
   }) {device = "/device:CPU:0"} : () -> ()
   %execute = "tf_device.launch"() ( {
     %3 = "tf.TPUExecute"(%arg0, %2, %compile#1)
-      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
     tf_device.return %3 : tensor<i32>
   }) {device = "/device:TPU:0"} : () -> tensor<i32>
   return %execute : tensor<i32>
@@ -246,9 +246,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
   // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
   // CHECK: %[[ITER1:.*]]:2 = "tf.IteratorGetNext"
@@ -267,7 +267,7 @@
       {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} {
     // CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1)
     %execute = "tf_device.launch"() ( {
-      %4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      %4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
       tf_device.return %4 : tensor<i32>
     }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i32>
     tf_device.return %execute : tensor<i32>
@@ -286,9 +286,9 @@
       NumDynamicShapes = 0 : i64,
      // The metadata encodes two parameters and two return values.
       metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
   // CHECK-NOT: "tf.TPUGetLayoutOp"
   // CHECK-NOT: "tf.TPUCopyWithLayout"
   "tf_device.launch"() ( {
@@ -300,7 +300,7 @@
     %2:2 = "tf.IteratorGetNext"(%r0)
       : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
     %execute = "tf_device.launch"() ( {
-      %4 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
+      %4 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor<i32>
       tf_device.return %4 : tensor<i32>
     }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i32>
     tf_device.return %execute : tensor<i32>
@@ -330,9 +330,9 @@
   // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch"
   // CHECK-NEXT: "tf._TPUCompileMlir"()
   %compile:3 = "tf_device.launch"() ( {
-    %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\01 \02", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1, %1#2 : tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>)
+    %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\01 \02", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1, %1#2 : tensor<!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>)
   // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
   // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false}
   // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext"
@@ -351,7 +351,7 @@
     // CHECK-NEXT: tf_device.return
     // CHECK-NEXT: device = "/device:TPU:0"
     "tf_device.launch"() ( {
-      "tf.TPUExecute"(%2#0, %compile#1) : (tensor<128xf32>, tensor<!tf.string>) -> ()
+      "tf.TPUExecute"(%2#0, %compile#1) : (tensor<128xf32>, tensor<2x!tf.string>) -> ()
       tf_device.return
     }) {device = "/device:TPU:0"} : () -> ()
     tf_device.return
@@ -364,7 +364,7 @@
     // CHECK-NEXT: tf_device.return
     // CHECK-NEXT: device = "/device:TPU:1"
     "tf_device.launch"() ( {
-      "tf.TPUExecute"(%2#1, %compile#2) : (tensor<128xf32>, tensor<!tf.string>) -> ()
+      "tf.TPUExecute"(%2#1, %compile#2) : (tensor<128xf32>, tensor<2x!tf.string>) -> ()
       tf_device.return
     }) {device = "/device:TPU:1"} : () -> ()
     tf_device.return
@@ -396,9 +396,9 @@
   // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch"
   // CHECK-NEXT: "tf._TPUCompileMlir"()
   %compile:3 = "tf_device.launch"() ( {
-    %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\02 \02", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>)
-    tf_device.return %1#0, %1#1, %1#2 : tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>
-  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>, tensor<!tf.string>)
+    %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\02 \02", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>)
+    tf_device.return %1#0, %1#1, %1#2 : tensor<!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>
+  }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>)
   // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
   // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false}
   // CHECK-DAG: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"(%[[ARG0]])
@@ -423,7 +423,7 @@
       // CHECK-NEXT: tf_device.return
       // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
       "tf_device.launch"() ( {
-        "tf.TPUExecute"(%r0, %compile#1) : (tensor<128xf32>, tensor<!tf.string>) -> ()
+        "tf.TPUExecute"(%r0, %compile#1) : (tensor<128xf32>, tensor<2x!tf.string>) -> ()
         tf_device.return
       }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
       tf_device.return
@@ -433,7 +433,7 @@
       // CHECK-NEXT: tf_device.return
       // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1"
       "tf_device.launch"() ( {
-        "tf.TPUExecute"(%r1, %compile#2) : (tensor<128xf32>, tensor<!tf.string>) -> ()
+        "tf.TPUExecute"(%r1, %compile#2) : (tensor<128xf32>, tensor<2x!tf.string>) -> ()
         tf_device.return
       }) {device = "TPU_REPLICATED_CORE_1"} : () -> ()
       tf_device.return
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir
new file mode 100644
index 0000000..a505a4e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-resource-read-for-write.mlir
@@ -0,0 +1,64 @@
+// RUN: tf-opt -tf-tpu-resource-read-for-write %s | FileCheck %s --dump-input=always
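+
+// This pass inserts a tf.ReadVariableOp for a resource that a cluster writes
+// but never reads, forwarding the read value into tf_device.cluster_func.
+// Resources that are already read, written more than once, or whose cluster
+// result has other users are left unchanged.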
+
+// CHECK-LABEL: func @write_only_resource
+// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>, [[ARG1:%.*]]: tensor<f32>, [[ARG2:%.*]]: tensor<*x!tf.resource<tensor<i32>>>)
+func @write_only_resource(%arg0: tensor<i32>, %arg1: tensor<f32>, %arg2: tensor<*x!tf.resource<tensor<i32>>>) {
+  // CHECK-NEXT: [[READ:%.*]] = "tf.ReadVariableOp"([[ARG2]])
+  // CHECK-NEXT: [[CLUSTER:%.*]]:2 = "tf_device.cluster_func"([[ARG0]], [[ARG1]], [[READ]])
+  // CHECK-SAME: _tpu_replicate = "write"
+  %0:2 = "tf_device.cluster_func"(%arg0, %arg1) {_tpu_replicate = "write", func = @write_func} : (tensor<i32>, tensor<f32>) -> (tensor<f32>, tensor<i32>)
+  // CHECK-NEXT: "tf.AssignVariableOp"([[ARG2]], [[CLUSTER]]#1)
+  "tf.AssignVariableOp"(%arg2, %0#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
+  // CHECK-NEXT: return
+  return
+}
+
+// CHECK-LABEL: func @write_func
+// CHECK-SAME: ({{%.*}}: tensor<i32>, {{%.*}}: tensor<f32>, {{%.*}}: tensor<i32>) -> (tensor<f32>, tensor<i32>)
+func @write_func(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i32>) {
+  return %arg1, %arg0 : tensor<f32>, tensor<i32>
+}
+
+// CHECK-LABEL: func @read_write_resource
+func @read_write_resource(%arg0: tensor<i32>, %arg1: tensor<f32>, %arg2: tensor<*x!tf.resource<tensor<i32>>>) {
+  // CHECK-COUNT-1: tf.ReadVariableOp
+  %0 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
+  %1:2 = "tf_device.cluster_func"(%arg0, %arg1, %0) {_tpu_replicate = "read_write", func = @read_write_func} : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<f32>, tensor<i32>)
+  "tf.AssignVariableOp"(%arg2, %1#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @read_write_func
+// CHECK-SAME: ({{%.*}}: tensor<i32>, {{%.*}}: tensor<f32>) -> (tensor<f32>, tensor<i32>)
+func @read_write_func(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i32>) {
+  return %arg1, %arg0 : tensor<f32>, tensor<i32>
+}
+
+// CHECK-LABEL: func @multiple_write_resource
+func @multiple_write_resource(%arg0: tensor<i32>, %arg1: tensor<*x!tf.resource<tensor<i32>>>) {
+  // CHECK-NOT: tf.ReadVariableOp
+  %0:2 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_write", func = @multiple_write_func} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
+  "tf.AssignVariableOp"(%arg1, %0#0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
+  "tf.AssignVariableOp"(%arg1, %0#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
+  return
+}
+
+// CHECK-LABEL: func @multiple_write_func
+// CHECK-SAME: ({{%.*}}: tensor<i32>) -> (tensor<i32>, tensor<i32>)
+func @multiple_write_func(%arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
+  return %arg0, %arg0 : tensor<i32>, tensor<i32>
+}
+
+// CHECK-LABEL: func @multiple_result_user
+func @multiple_result_user(%arg0: tensor<i32>, %arg1: tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32> {
+  // CHECK-NOT: tf.ReadVariableOp
+  %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_uses", func = @multiple_result_user_func} : (tensor<i32>) -> tensor<i32>
+  "tf.AssignVariableOp"(%arg1, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
+  return %0 : tensor<i32>
+}
+
+// CHECK-LABEL: func @multiple_result_user_func
+// CHECK-SAME: ({{%.*}}: tensor<i32>) -> tensor<i32>
+func @multiple_result_user_func(%arg0: tensor<i32>) -> tensor<i32> {
+  return %arg0 : tensor<i32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
index 1e308b4..277e4a8 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir
@@ -61,9 +61,9 @@
         NumDynamicShapes = 0 : i64,
        // The metadata encodes two parameters and two return values.
         metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
-    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
+    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
     "tf_device.launch"() ( {
       "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
       tf_device.return
@@ -86,7 +86,7 @@
       "tf_device.launch"() ( {
         "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
               {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
-                : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
+                : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
         tf_device.return
       }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
       %ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -153,9 +153,9 @@
         NumDynamicShapes = 0 : i64,
        // The metadata encodes two parameters and two return values.
         metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
-    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
+    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
     "tf_device.launch"() ( {
       "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
       tf_device.return
@@ -173,7 +173,7 @@
         "tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1)
               {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
                 : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
-                   tensor<*x!tf.resource<tensor<f32>>>, tensor<!tf.string>) -> ()
+                   tensor<*x!tf.resource<tensor<f32>>>, tensor<2x!tf.string>) -> ()
         tf_device.return
       }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
       tf_device.return
@@ -239,9 +239,9 @@
         NumDynamicShapes = 0 : i64,
        // The metadata encodes two parameters and two return values.
         metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
-    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
+    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
     "tf_device.launch"() ( {
       "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
       tf_device.return
@@ -254,7 +254,7 @@
         "tf_device.launch"() ( {
           "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
                 {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
-                  : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
+                  : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
           tf_device.return
         }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
         tf_device.return
@@ -342,9 +342,9 @@
         NumDynamicShapes = 0 : i64,
        // The metadata encodes two parameters and two return values.
         metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
-        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
-      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
-    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+        mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
+      tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<2x!tf.string>
+    }) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
     "tf_device.launch"() ( {
       "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
       tf_device.return
@@ -367,7 +367,7 @@
       "tf_device.launch"() ( {
         "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
               {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
-                : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
+                : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<2x!tf.string>) -> ()
         tf_device.return
       }) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
       %ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
index 32a8000..d897c8c 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
@@ -173,7 +173,7 @@
   func @tail_single_outside_compiled_op() {
     // CHECK:      %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
     // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
-    // CHECK-NEXT:   "tf.C"
+    // CHECK-NEXT:   "tf.NoOp"
     // CHECK-NEXT:   tf_device.return %[[A_OUT]]
     // CHECK-NEXT: {
     // CHECK-DAG:  num_cores_per_replica = 1
@@ -190,7 +190,7 @@
     "tf_device.cluster"() ( {
       %a = "tf.A"() : () -> tensor<i32>
       "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
-      "tf.C"() : () -> ()
+      "tf.NoOp"() : () -> ()
       tf_device.return
     }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
     return
@@ -200,7 +200,7 @@
   func @tail_single_outside_compiled_op_user() -> tensor<i32> {
     // CHECK:      %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
     // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
-    // CHECK-NEXT:   "tf.C"
+    // CHECK-NEXT:   "tf.NoOp"
     // CHECK-NEXT:   tf_device.return %[[A_OUT]]
     // CHECK-NEXT: {
     // CHECK-DAG:  num_cores_per_replica = 1
@@ -217,7 +217,7 @@
     %cluster = "tf_device.cluster"() ( {
       %a = "tf.A"() : () -> tensor<i32>
       %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
-      "tf.C"() : () -> ()
+      "tf.NoOp"() : () -> ()
       tf_device.return %b : tensor<i32>
     }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
     // CHECK:      return %[[LAUNCH_OUT]]
@@ -262,7 +262,7 @@
     %b = "tf.B"() : () -> tensor<i32>
     // CHECK:      %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
     // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"
-    // CHECK-NEXT:   %[[E_OUT:.*]] = "tf.E"
+    // CHECK-NEXT:   %[[E_OUT:.*]] = "tf.Const"
     // CHECK-NEXT:   tf_device.return %[[C_OUT]], %[[E_OUT]]
     // CHECK-NEXT: {
     // CHECK-DAG:  num_cores_per_replica = 1
@@ -279,7 +279,7 @@
     %cluster:5 = "tf_device.cluster"() ( {
       %c = "tf.C"()  : () -> tensor<i32>
       %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
-      %e = "tf.E"()  : () -> tensor<i32>
+      %e = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
       tf_device.return %a, %b, %c, %d, %e : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
     }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
     // CHECK:      return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1
@@ -320,14 +320,14 @@
   func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor<i32>) {
     // CHECK-NOT:  "tf_device.launch"
     // CHECK:      "tf_device.cluster"
-    // CHECK-NEXT:   "tf.A"
+    // CHECK-NEXT:   "tf.Identity"
     // CHECK-NEXT:   "tf.B"
-    // CHECK-NEXT:   "tf.C"
+    // CHECK-NEXT:   "tf.Identity"
     // CHECK-NEXT:   tf_device.return
     "tf_device.cluster"() ( {
-      %a = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
+      %a = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
       %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
-      "tf.C"(%b) : (tensor<i32>) -> ()
+      %c = "tf.Identity"(%b) : (tensor<i32>) -> tensor<i32>
       tf_device.return
     }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
     return
@@ -379,7 +379,7 @@
     // CHECK:        %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
     // CHECK-NEXT:     %[[B_OUT:.*]] = "tf.B"
     // CHECK-NEXT:     %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]])
-    // CHECK-NEXT:     "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
+    // CHECK-NEXT:     "tf.IdentityN"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
     // CHECK-NEXT:     tf_device.return %[[C_OUT]]
     // CHECK-NEXT:   {
     // CHECK-DAG:    num_cores_per_replica = 1
@@ -399,11 +399,72 @@
         %b = "tf.B"() : () -> tensor<i32>
         %c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
         %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
-        %e = "tf.E"(%c, %a) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+        %e:2 = "tf.IdentityN"(%c, %a) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
         tf_device.return
       }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
       tf_device.return
     }
     return
   }
+
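+  // The three tests below check that a side-effecting outside compiled op
+  // with no operands or results is extracted only when it forms part of the
+  // head or tail of the cluster; in the middle it is left in place.
+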
+  // CHECK-LABEL: func @side_effect_middle
+  func @side_effect_middle() {
+    // CHECK:      "tf_device.cluster"
+    // CHECK-NEXT:   "tf.A"
+    // CHECK-NEXT:   "tf.B"
+    // CHECK-NEXT:   "tf.C"
+    // CHECK-NEXT:   tf_device.return
+    "tf_device.cluster"() ( {
+      "tf.A"() : () -> ()
+      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      "tf.C"() : () -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
+    return
+  }
+
+  // CHECK-LABEL: func @side_effect_head_no_operand
+  func @side_effect_head_no_operand() {
+    // CHECK:      %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
+    // CHECK-NEXT:   "tf.B"
+    // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"
+    // CHECK-NEXT:   tf_device.return %[[C_OUT]]
+    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
+
+    // CHECK:      "tf_device.cluster"
+    // CHECK-NEXT:   "tf.Const"
+    // CHECK-NEXT:   "tf.D"(%[[HEAD_LAUNCH_OUT]])
+    // CHECK-NEXT:   tf_device.return
+
+    "tf_device.cluster"() ( {
+      %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      %c = "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> tensor<i32>
+      "tf.D"(%c) : (tensor<i32>) -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
+    return
+  }
+
+  // CHECK-LABEL: func @side_effect_tail_no_operand
+  func @side_effect_tail_no_operand() {
+    // CHECK:      %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
+    // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
+    // CHECK-NEXT:   "tf.Const"
+    // CHECK-NEXT:   tf_device.return %[[A_OUT]]
+
+    // CHECK:      "tf_device.launch"()
+    // CHECK-NEXT:   "tf.B"(%[[CLUSTER_OUT]])
+    // CHECK-NEXT:   "tf.C"
+    // CHECK-NEXT:   tf_device.return
+    // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
+    "tf_device.cluster"() ( {
+      %a = "tf.A"() : () -> tensor<i32>
+      "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
+      "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+      tf_device.return
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
+    return
+  }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
index 1f516a2..2271bca 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
@@ -512,6 +512,137 @@
     return %1 : tensor<?xi32>
   }
 
+  // Tests extraction of an outside compiled tf.IfRegion op where the entire
+  // tf.IfRegion op is outside compiled.
+
+  // CHECK-LABEL: func @outside_compiled_tf_if
+  func @outside_compiled_tf_if(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    // CHECK:      %[[A_OUT:[0-9]*]] = "tf.A"
+    // CHECK:      %[[F_OUT:[0-9]*]] = "tf.F"
+    // CHECK:      %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK:        %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+    // CHECK-NEXT:     "tf_device.launch"
+    // CHECK-NEXT:       %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+    // CHECK-NEXT:       %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME:       device_ordinal = 0
+    // CHECK-SAME:       key = "host_compute_channel_cluster1_args"
+    // CHECK-SAME:       (tensor<2x!tf.string>) -> (tensor<?xi32>, tensor<?xi32>, tensor<i1>)
+    // CHECK-NEXT:       "tf.IfRegion"(%[[RECV_OUTPUT]]#2)
+    // CHECK:              "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]])
+    // CHECK:              "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME:         device_ordinal = 0
+    // CHECK-SAME:         key = "host_compute_channel_cluster1_retvals"
+    // CHECK:          "tf_device.cluster"
+    // CHECK:            %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK:            %[[B_OUTPUT:[0-9]*]] = "tf.B"
+    // CHECK:            %[[G_OUTPUT:[0-9]*]] = "tf.G"
+    // CHECK:            "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]])
+    // CHECK-SAME:       recv_key = "host_compute_channel_cluster1_retvals"
+    // CHECK-SAME:       send_key = "host_compute_channel_cluster1_args"
+    // CHECK-SAME:       tpu_core = 0
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    %7 = "tf.F"() : () -> tensor<?xi32>
+
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<?xi32>)
+        %4 = "tf.B"() : () -> (tensor<?xi32>)
+        %6 = "tf.G"() : () -> (tensor<i1>)
+
+        "tf.IfRegion"(%6) ({
+          "tf.D"(%4, %3, %7) {} : (tensor<?xi32>, tensor<?xi32>, tensor<?xi32>) -> ()
+          "tf.Yield"() : () -> ()
+        }, {
+          "tf.Yield"() : () -> ()
+        }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor<i1>) -> ()
+
+        %5 = "tf.E"() : () -> tensor<?xi32>
+        tf_device.return %5 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1 : tensor<?xi32>
+  }
+
+  // Tests extraction of an outside compiled tf.IfRegion op where the entire
+  // tf.IfRegion op is outside compiled and wrapped inside another
+  // tf.IfRegion op.
+
+  // CHECK-LABEL: func @outside_compiled_tf_if_nested
+  func @outside_compiled_tf_if_nested(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    // CHECK:      %[[A_OUT:[0-9]*]] = "tf.A"
+    // CHECK:      %[[F_OUT:[0-9]*]] = "tf.F"
+    // CHECK:      %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK:        %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+    // CHECK-NEXT:     "tf_device.launch"
+    // CHECK-NEXT:       %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+    // CHECK-NEXT:       %[[RECV_OUTPUT_PREDICATE:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME:       device_ordinal = 0
+    // CHECK-SAME:       key = "if_predicate_channel_cluster1_0"
+    // CHECK-SAME:       (tensor<2x!tf.string>) -> tensor<i1>
+    // CHECK-NEXT:       "tf.IfRegion"(%[[RECV_OUTPUT_PREDICATE]])
+    // CHECK-NEXT:         %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME:         device_ordinal = 0
+    // CHECK-SAME:         key = "host_compute_channel_cluster1_args"
+    // CHECK-SAME:         (tensor<2x!tf.string>) -> (tensor<?xi32>, tensor<i1>)
+    // CHECK-NEXT:         "tf.IfRegion"(%[[RECV_OUTPUT]]#1)
+    // CHECK-NEXT:           "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]])
+    // CHECK:                "tf.Yield"() : () -> ()
+    // CHECK:                "tf.Yield"() : () -> ()
+    // CHECK:              "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME:         device_ordinal = 0
+    // CHECK-SAME:         key = "host_compute_channel_cluster1_retvals"
+    // CHECK:          "tf_device.cluster"
+    // CHECK:            %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK:            %[[B_OUTPUT:[0-9]*]] = "tf.B"
+    // CHECK:            %[[G_OUTPUT:[0-9]*]] = "tf.G"
+    // CHECK:            "tf.XlaSendToHost"(%[[G_OUTPUT]])
+    // CHECK-SAME:       key = "if_predicate_channel_cluster1_0"
+    // CHECK-SAME:       (tensor<i1>) -> ()
+    // CHECK-NEXT:       "tf.IfRegion"(%[[G_OUTPUT]])
+    // CHECK:              %[[D_OUT:[0-9]*]] = "tf.D"
+    // CHECK-NEXT:         %[[F_OUT:[0-9]*]] = "tf.F"
+    // CHECK:              "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]])
+    // CHECK-SAME:         recv_key = "host_compute_channel_cluster1_retvals"
+    // CHECK-SAME:         send_key = "host_compute_channel_cluster1_args"
+    // CHECK-SAME:         tpu_core = 0
+    // CHECK:              "tf.Yield"() : () -> ()
+    // CHECK:              "tf.Yield"() : () -> ()
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+    %7 = "tf.F"() : () -> tensor<?xi32>
+
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<?xi32>)
+        %4 = "tf.B"() : () -> (tensor<?xi32>)
+        %6 = "tf.G"() : () -> (tensor<i1>)
+
+        "tf.IfRegion"(%6) ({
+          %8 = "tf.D"(%4, %3, %7) {} : (tensor<?xi32>, tensor<?xi32>, tensor<?xi32>) -> (tensor<?xi32>)
+          %9 = "tf.F"(%4) {} : (tensor<?xi32>) -> (tensor<i1>)
+
+          "tf.IfRegion"(%9) ({
+            "tf.H"(%8, %7) : (tensor<?xi32>, tensor<?xi32>) -> ()
+            "tf.Yield"() : () -> ()
+          }, {
+            "tf.Yield"() : () -> ()
+          }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor<i1>) -> ()
+
+          "tf.Yield"() : () -> ()
+        }, {
+          "tf.Yield"() : () -> ()
+        }) {is_stateless = false} : (tensor<i1>) -> ()
+
+        %5 = "tf.E"() : () -> tensor<?xi32>
+        tf_device.return %5 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1 : tensor<?xi32>
+  }
+
   // Tests extraction of a single outside compiled cluster inside a tf.IfRegion
   // op with return values.
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index 2a0091c..ef7b52c 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -1262,15 +1262,15 @@
       // CHECK-NOT:"tf._TPUCompileMlirPlaceholderProgramKey"
       // CHECK:    "tf.E"(%[[COMPILE_OUTPUT]]#1
       %3 = "tf_device.parallel_execute"() ( {
-         %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<?x!tf.string>
-        "tf.D"(%program) : (tensor<?x!tf.string>) -> ()
+         %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<2x!tf.string>
+        "tf.D"(%program) : (tensor<2x!tf.string>) -> ()
         tf_device.return
       }, {
         %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
         tf_device.return %4 : tensor<?xi32>
       }, {
-        %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<?x!tf.string>
-        "tf.E"(%program) : (tensor<?x!tf.string>) -> ()
+        %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<2x!tf.string>
+        "tf.E"(%program) : (tensor<2x!tf.string>) -> ()
         tf_device.return
       }) : () -> (tensor<?xi32>)
       tf_device.return %3 : tensor<?xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
index 9107a64..9960ca7 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc
@@ -101,13 +101,16 @@
   pm.addPass(TFDevice::CreateResourceOpLiftingPass());
   pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
   pm.addPass(mlir::createInlinerPass());
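+  // Mark ops for outside compilation; the extraction passes below then move
+  // the marked ops out of the TPU cluster so they run on the host.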
+  pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
   pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
+  pm.addPass(CreateTPUExtractOutsideCompilationPass());
   pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
 
   pm.addNestedPass<FuncOp>(tf_executor::CreateTFExecutorConstantSinkingPass());
   pm.addPass(TF::CreateResourceDeviceInferencePass());
   pm.addPass(TFDevice::CreateClusterOutliningPass());
   pm.addPass(CreateTPUDynamicPaddingMapperPass());
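+  // Insert reads for resources that clusters only write, so the resource
+  // values can be forwarded into the cluster functions (see
+  // tpu-resource-read-for-write.mlir).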
+  pm.addPass(CreateTPUResourceReadForWritePass());
   pm.addPass(CreateTPUShardingIdentificationPass());
   pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
   pm.addPass(CreateTPURewritePass());
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc
index 2b8ab85..e85058a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc
@@ -39,6 +39,10 @@
 
 struct ClusterFormationPass
     : public PassWrapper<ClusterFormationPass, FunctionPass> {
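+  // Cluster formation creates tf_device.launch ops, so the tf_device dialect
+  // must be declared as a dependency to ensure it is loaded before this pass
+  // runs.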
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<tf_device::TensorFlowDeviceDialect>();
+  }
+
   void runOnFunction() override;
 };
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
index b473787..cc24c98 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
@@ -240,7 +240,7 @@
       auto def_op = val.getDefiningOp();
 #ifndef NDEBUG
       auto exec_dialect =
-          function.getContext()->getRegisteredDialect("tf_executor");
+          function.getContext()->getLoadedDialect("tf_executor");
       assert(def_op->getDialect() == exec_dialect &&
              "unable to forward control dependencies");
 #endif
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc
index 175baeb..fbe0524 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc
@@ -91,7 +91,7 @@
 
     // Build the newly fused operation to replace the batch norm
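+    // (_FusedBatchNormEx is the op's underscore-prefixed registered name in
+    // TensorFlow, which the renamed op class reflects.)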
     OperationState state(batch_norm.getLoc(),
-                         FusedBatchNormExOp::getOperationName());
+                         _FusedBatchNormExOp::getOperationName());
     state.addOperands(batch_norm.getOperands());
     if (side_input) state.operands.push_back(side_input);
     state.addTypes(batch_norm.getResultTypes());
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc
index 9f67a3e..4e507c8 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc
@@ -104,7 +104,7 @@
 }
 
 void LaunchToDeviceAttributePass::runOnFunction() {
-  const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
+  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
   if (!tf_dialect) {
     getFunction().emitError() << "'tf' dialect is not registered";
     return signalPassFailure();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
index ad241ef..e64206d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc
@@ -615,6 +615,10 @@
 };
 
 class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
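+  // The conversion patterns produce TF dialect ops, so the dialect must be
+  // declared as a dependency to ensure it is loaded before this pass runs.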
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<TF::TensorFlowDialect>();
+  }
+
  public:
   LegalizeHloToTf() = default;
   LegalizeHloToTf(const LegalizeHloToTf &) {}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
index d8e1709..4385a2d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
@@ -56,18 +56,27 @@
   return DenseIntElementsAttr::get(ty, vals);
 }
 
-// Returns int or float DenseElementsAttr with scalar shape with the given
-// element type and the integer value.
+// Returns int, float, or complex DenseElementsAttr with scalar shape with the
+// given element type and the integer value.
 static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
   RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
   if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
     FloatAttr attr = FloatAttr::get(float_ty, raw_value);
     return DenseElementsAttr::get(scalar_ty, attr);
+  } else if (auto int_ty = ty.dyn_cast_or_null<IntegerType>()) {
+    IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
+    return DenseElementsAttr::get(scalar_ty, attr);
+  } else if (auto complex_ty = ty.dyn_cast_or_null<ComplexType>()) {
+    Type complex_element_ty = complex_ty.getElementType();
+    if (complex_element_ty.isF32()) {
+      return DenseElementsAttr::get(
+          scalar_ty, static_cast<std::complex<float>>(raw_value));
+    } else if (complex_element_ty.isF64()) {
+      return DenseElementsAttr::get(
+          scalar_ty, static_cast<std::complex<double>>(raw_value));
+    }
   }
-
-  auto int_ty = ty.cast<IntegerType>();
-  IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
-  return DenseElementsAttr::get(scalar_ty, attr);
+  llvm_unreachable("unsupported type");
 }
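+// For example (illustrative values): GetScalarOfType(f32, 1) yields
+// dense<1.0> : tensor<f32>, and a complex<f32> element type yields
+// dense<(1.0,0.0)> : tensor<complex<f32>>; element types outside
+// float/int/complex hit the llvm_unreachable above.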
 
 // Returns float DenseElementsAttr with scalar shape with the specified value.
@@ -112,6 +121,36 @@
   return RankedTensorType::get(shape, ranked_ty.getElementType());
 }
 
+// Indexes a rank 1 tensor returning a rank 1 singleton.
+Value IndexRank1(PatternRewriter &rewriter, Location loc, Value tensor,
+                 int64_t index) {
+  auto index_val =
+      rewriter.create<TF::ConstOp>(loc, GetI64ElementsAttr({index}, &rewriter));
+  auto size_val =
+      rewriter.create<TF::ConstOp>(loc, GetI64ElementsAttr({1}, &rewriter));
+  auto type = RankedTensorType::get({1}, rewriter.getIntegerType(64));
+  return rewriter.create<TF::SliceOp>(loc, type, tensor, index_val, size_val);
+}
+
+// Converts a rank 1 tensor to individual Values, each rank 1 with size 1.
+void Rank1ToValues(PatternRewriter &rewriter, Location loc, int64_t length,
+                   Value tensor, SmallVectorImpl<Value> &out) {
+  for (int64_t i = 0; i < length; ++i) {
+    out.push_back(IndexRank1(rewriter, loc, tensor, i));
+  }
+}
+
+// Converts individual Values to a tensor of rank 1. Each input Value has rank 1
+// and size 1.
+Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype,
+                    ArrayRef<Value> vals) {
+  int64_t length = vals.size();
+  auto type = RankedTensorType::get({length}, dtype);
+  auto axis = rewriter.create<TF::ConstOp>(
+      loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
+  return rewriter.create<TF::ConcatV2Op>(loc, type, ValueRange(vals), axis);
+}
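+
+// A minimal usage sketch for the three helpers above, assuming a hypothetical
+// rank-1 value `shape` of type tensor<4xi64>:
+//   SmallVector<Value, 4> dims;
+//   Rank1ToValues(rewriter, loc, /*length=*/4, shape, dims);
+//   // dims now holds four tensor<1xi64> slices of `shape`.
+//   Value rebuilt =
+//       ValuesToRank1(rewriter, loc, rewriter.getIntegerType(64), dims);
+// These helpers support the SpaceToBatchND lowering below when assembling
+// new shape tensors.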
+
 // Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
 //
 // Note that to improve the parallelism, AddN op uses tree-based reduction.
@@ -384,6 +423,158 @@
   }
 };
 
+// Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))).
+//
+// Before rewrite:
+//   output = SpaceToBatchND(input, block_shape, paddings)
+// Let:
+//   [batch] + spatial_shape + remaining_shape = input.shape
+//   M = spatial_shape.rank
+// After rewrite:
+//   padded = zero-pad input with paddings
+//     The spatial_shape component of input.shape pads with paddings[*, 0]
+//     before each dimension, and paddings[*, 1] after each dimension.
+//   reshaped = reshape padded to:
+//     [batch]
+//     + [padded.shape[1]/block_shape[0], block_shape[0], ...,
+//        padded.shape[M]/block_shape[M-1], block_shape[M-1]]
+//     + remaining_shape
+//   permuted = transpose reshaped to:
+//     block_shape
+//     + [batch]
+//     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
+//     + remaining_shape
+//   result = reshape permuted to:
+//     [batch * product(block_shape)]
+//     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
+//     + remaining_shape
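+//
+// For example, with input shape [1, 4, 4, 1], block_shape = [2, 2], and zero
+// paddings (M = 2, remaining_shape = [1]):
+//   padded   : [1, 4, 4, 1]
+//   reshaped : [1, 2, 2, 2, 2, 1]
+//   permuted : [2, 2, 1, 2, 2, 1]
+//   result   : [4, 2, 2, 1]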
+class LowerSpaceToBatchNDOp : public OpRewritePattern<TF::SpaceToBatchNDOp> {
+ public:
+  using OpRewritePattern<TF::SpaceToBatchNDOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TF::SpaceToBatchNDOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto input_type = op.input().getType().cast<TensorType>();
+    if (!input_type.hasStaticShape()) {
+      return failure();
+    }
+    auto block_shape_type = op.block_shape().getType().cast<TensorType>();
+    if (!block_shape_type.hasStaticShape() ||
+        !block_shape_type.getElementType().isSignlessInteger(64)) {
+      return failure();
+    }
+    auto paddings_type = op.paddings().getType().cast<ShapedType>();
+    if (!paddings_type.getElementType().isSignlessInteger(64)) {
+      // TODO(b/157475606): Add support for 32 bit signless integers. Currently
+      // the ConcatV2Op would receive arguments with inconsistent dtypes, and
+      // the SliceOps would produce an output dtype different from their input
+      // dtype.
+      return failure();
+    }
+    int64_t input_rank = input_type.getRank();
+    int64_t block_rank = block_shape_type.getNumElements();
+    int64_t remaining_rank = input_rank - 1 - block_rank;
+    ArrayRef<int64_t> input_shape = input_type.getShape();
+    auto pad00 = rewriter.create<TF::ConstOp>(
+        loc, DenseElementsAttr::get<int64_t>(
+                 RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)),
+                 {0, 0}));
+    SmallVector<Value, 4> full_paddings_list{pad00, op.paddings()};
+    full_paddings_list.append(remaining_rank, pad00);
+    auto full_paddings_type =
+        RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64));
+    auto padded_axis = rewriter.create<TF::ConstOp>(
+        loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
+    // Extends paddings to all dimensions of input by adding 0s to non-block
+    // dimensions.
+    auto full_paddings = rewriter.create<TF::ConcatV2Op>(
+        loc, full_paddings_type, full_paddings_list, padded_axis);
+    SmallVector<int64_t, 4> padded_shape(input_rank, -1);
+    auto padded_type =
+        RankedTensorType::get(padded_shape, input_type.getElementType());
+    // padded = pad(input, full_paddings)
+    auto padded =
+        rewriter.create<TF::PadOp>(loc, padded_type, op.input(), full_paddings);
+    auto paddings_sum_type =
+        RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
+    // paddings_sum = paddings[*,0] + paddings[*,1]
+    auto paddings_sum = rewriter.create<TF::EinsumOp>(
+        loc, paddings_sum_type, ValueRange({full_paddings}), "ij->i");
+    // input_shape_tensor = input.shape
+    auto input_shape_tensor = rewriter.create<TF::ConstOp>(
+        loc,
+        DenseElementsAttr::get(
+            RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)),
+            input_shape));
+    // padded_shape_tensor is the shape of padded.
+    auto padded_shape_tensor =
+        rewriter.create<TF::AddOp>(loc, paddings_sum, input_shape_tensor);
+    SmallVector<Value, 4> padded_shape_vals;
+    Rank1ToValues(rewriter, loc, input_rank, padded_shape_tensor,
+                  padded_shape_vals);
+    SmallVector<Value, 4> block_shape_vals;
+    Rank1ToValues(rewriter, loc, block_rank, op.block_shape(),
+                  block_shape_vals);
+    SmallVector<Value, 4> outer_shape_vals;
+    for (int64_t i = 0; i < block_rank; ++i) {
+      // TODO(b/157475606): Insert dynamic check that the following division has
+      // remainder 0.
+      outer_shape_vals.push_back(rewriter.create<TF::DivOp>(
+          loc, padded_shape_vals[1 + i], block_shape_vals[i]));
+    }
+
+    SmallVector<Value, 6> reshaped_shape_vals{padded_shape_vals[0]};
+    for (int64_t i = 0; i < block_rank; ++i) {
+      reshaped_shape_vals.push_back(outer_shape_vals[i]);
+      reshaped_shape_vals.push_back(block_shape_vals[i]);
+    }
+    for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
+      reshaped_shape_vals.push_back(padded_shape_vals[i]);
+    }
+    auto reshaped_shape = ValuesToRank1(
+        rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);
+
+    SmallVector<Value, 6> permutation_vals;
+    for (int64_t i = 0; i < block_rank; ++i) {
+      permutation_vals.push_back(rewriter.create<TF::ConstOp>(
+          loc, GetI64ElementsAttr({2 + 2 * i}, &rewriter)));
+    }
+    permutation_vals.push_back(
+        rewriter.create<TF::ConstOp>(loc, GetI64ElementsAttr({0}, &rewriter)));
+    for (int64_t i = 0; i < block_rank; ++i) {
+      permutation_vals.push_back(rewriter.create<TF::ConstOp>(
+          loc, GetI64ElementsAttr({1 + 2 * i}, &rewriter)));
+    }
+    for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
+      permutation_vals.push_back(rewriter.create<TF::ConstOp>(
+          loc, GetI64ElementsAttr({block_rank + i}, &rewriter)));
+    }
+    auto permutation = ValuesToRank1(rewriter, loc, rewriter.getIntegerType(64),
+                                     permutation_vals);
+
+    auto output_batch = padded_shape_vals[0];
+    for (int64_t i = 0; i < block_rank; ++i) {
+      output_batch =
+          rewriter.create<TF::MulOp>(loc, output_batch, block_shape_vals[i]);
+    }
+    SmallVector<Value, 4> output_shape_vals{output_batch};
+    for (int64_t i = 0; i < block_rank; ++i) {
+      output_shape_vals.push_back(outer_shape_vals[i]);
+    }
+    for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
+      output_shape_vals.push_back(padded_shape_vals[i]);
+    }
+    auto output_shape = ValuesToRank1(
+        rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);
+    auto reshaped = rewriter.create<TF::ReshapeOp>(loc, padded, reshaped_shape);
+    auto permuted =
+        rewriter.create<TF::TransposeOp>(loc, reshaped, permutation);
+
+    rewriter.replaceOpWithNewOp<TF::ReshapeOp>(op, permuted, output_shape);
+    return success();
+  }
+};
+
 // Lowers `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness hints,
 // since we currently don't have an implementation that can use this
 // information. Adds appropriate casts where necessary to align element types
@@ -438,8 +629,7 @@
   LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op,
                                 PatternRewriter &rewriter) const override {
     Value result = op.x();
-    for (StringRef op_name :
-         op.op_names().getAsRange<StringAttr, StringRef>()) {
+    for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
       std::string full_name = "tf." + op_name.str();
       // All ops in the sequences have the same result type as the original
       // result type.
@@ -458,8 +648,8 @@
 void PopulateLoweringTFPatterns(MLIRContext *context,
                                 OwningRewritePatternList *patterns) {
   patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerInvertPermutationOp,
-                   LowerPackOp, LowerSparseMatMulOp, Lower_UnaryOpsComposition>(
-      context);
+                   LowerPackOp, LowerSpaceToBatchNDOp, LowerSparseMatMulOp,
+                   Lower_UnaryOpsComposition>(context);
   populateWithGenerated(context, patterns);
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
index 6b7d717..f7a867f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
@@ -195,8 +195,7 @@
 // Reciprocal op patterns.
 //===----------------------------------------------------------------------===//
 
-// TODO(hinsu): Support complex and unsigned input types.
-def LowerReciprocal : Pat<(TF_ReciprocalOp TF_SintOrFpTensor:$x),
+def LowerReciprocal : Pat<(TF_ReciprocalOp $x),
                           (TF_DivOp (TF_ConstOp (GetScalarOfType<1> $x)), $x)>;
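+
+// With the TF_SintOrFpTensor restriction dropped, this pattern now matches
+// any element type; for example, Reciprocal of a tensor<complex<f32>> input
+// lowers to Div(Const(1+0i), x), relying on the complex support added to
+// GetScalarOfType.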
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
index 34b3347..4438f19 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
@@ -17,6 +17,7 @@
 #include <string>
 #include <utility>
 
+#include "llvm/Support/FormatVariadic.h"
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
@@ -24,6 +25,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 
@@ -116,15 +118,32 @@
 LogicalResult MarkUncompilableOps(
     const Dialect* tf_dialect, Block* block,
     llvm::DenseSet<OperationName>& supported_ops) {
+  // Ops that are automatically marked for outside compilation get an
+  // `_xla_outside_compilation` attribute value of "auto" plus an increasing
+  // counter. Manually marked ops use only the increasing counter for the
+  // attribute value, so the `_xla_outside_compilation` values of
+  // automatically and manually marked ops never collide.
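+  // For example, the first two automatically marked ops get the values
+  // "auto0" and "auto1" (via llvm::formatv("auto{0}", counter)), while a
+  // manually marked op carries a plain counter value such as "0".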
+  int outside_compiled_cluster_counter = 0;
   block->walk([&](Operation* op) {
     if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
-      op->setAttr(kXlaOutsideCompilationAttr,
-                  StringAttr::get("auto", op->getContext()));
+      op->setAttr(
+          kXlaOutsideCompilationAttr,
+          StringAttr::get(
+              llvm::formatv("auto{0}", outside_compiled_cluster_counter).str(),
+              op->getContext()));
+      outside_compiled_cluster_counter++;
     }
     if (llvm::isa<TF::IfRegionOp, TF::WhileRegionOp>(op)) {
       if (HasCapturedStringOperand(op)) {
-        op->setAttr(kXlaOutsideCompilationAttr,
-                    StringAttr::get("auto", op->getContext()));
+        op->setAttr(
+            kXlaOutsideCompilationAttr,
+            StringAttr::get(
+                llvm::formatv("auto{0}", outside_compiled_cluster_counter)
+                    .str(),
+                op->getContext()));
+        outside_compiled_cluster_counter++;
       }
     }
   });
@@ -152,13 +171,14 @@
 
 void MarkOpsForOutsideCompilation::runOnOperation() {
   auto module = getOperation();
-  const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
+  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
   if (!tf_dialect) {
     getOperation().emitError() << "'tf' dialect is not registered";
     return signalPassFailure();
   }
   OwningRewritePatternList patterns;
   mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
+  TF::PopulateLoweringTFPatterns(module.getContext(), &patterns);
 
   // `supported_ops` contains the name of all of the ops that can potentially be
   // lowered into HLO on the device. This doesn't always mean that the op can
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc
index 527af09..3526049 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc
@@ -39,6 +39,10 @@
 
 struct ParallelizeEmbeddingParamsOpsPass
     : public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<tf_device::TensorFlowDeviceDialect>();
+  }
+
   void runOnFunction() override;
 };
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index 1825511..7dcb1ca 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -79,6 +79,11 @@
 // Performs specific fusion for GPU targets.
 std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
 
+// Creates a pass that converts ops that copy tensors between devices, e.g.
+// tf.Identity.
+std::unique_ptr<OperationPass<mlir::FuncOp>>
+CreateTensorDeviceCopyConversionPass();
+
 struct LayoutOptimizationPipelineOptions
     : public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
   Option<std::string> force_data_format{
@@ -282,6 +287,10 @@
 // `tf_device.launch_func` `padding_map` attribute to its encapsulated function.
 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
 
+// Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources
+// the cluster only writes to.
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass();
+
 // Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime
 // ops.
 std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
index ef75f90..d99279c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
@@ -438,7 +438,7 @@
 
 void ReplicateToIslandPass::runOnOperation() {
   auto module = getOperation();
-  const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
+  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
   if (!tf_dialect) {
     module.emitError() << "'tf' dialect is not registered";
     return signalPassFailure();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index a4f41d0..e1b6940 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -54,6 +54,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/framework/types.pb.h"
@@ -116,12 +117,12 @@
 
 // Returns if the shape inference pass supports an op outside the TF dialect.
 bool IsSupportedNonTFOp(Operation* op) {
-  return isa<ReturnOp, tf_device::ReturnOp, tf_executor::EnterOp,
-             tf_executor::ExitOp, tf_executor::FetchOp, tf_executor::GraphOp,
-             tf_executor::IslandOp, tf_executor::LoopCondOp,
-             tf_executor::MergeOp, tf_executor::NextIterationSinkOp,
-             tf_executor::SwitchNOp, tf_executor::SwitchOp,
-             tf_executor::YieldOp>(op);
+  return isa<ReturnOp, tf_device::ReturnOp, tf_device::ClusterOp,
+             tf_device::LaunchOp, tf_executor::EnterOp, tf_executor::ExitOp,
+             tf_executor::FetchOp, tf_executor::GraphOp, tf_executor::IslandOp,
+             tf_executor::LoopCondOp, tf_executor::MergeOp,
+             tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
+             tf_executor::SwitchOp, tf_executor::YieldOp>(op);
 }
 
 // Returns whether a cast back would need to be inserted, e.g., whether the
@@ -597,7 +598,7 @@
                                bool propagate_caller_callee_constants)
     : graph_version_(graph_version),
       propagate_caller_callee_constants_(propagate_caller_callee_constants) {
-  tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
+  tf_dialect_ = context->getLoadedDialect<TensorFlowDialect>();
 }
 
 ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
@@ -745,6 +746,11 @@
     return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
                                             op->getResults());
   }
+  if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
+    auto terminator = cluster_op.GetBody().getTerminator();
+    return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
+                                            op->getResults());
+  }
   if (op->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
     return RefineShapeForPassThroughOps(op);
   }
@@ -815,18 +821,16 @@
     return false;
   }
 
-  // Convert the operation to a NodeDef to be able to use the InferenceContext
+  // Convert the operation attributes to be able to use the InferenceContext
   // and the TensorFlow shape function.
-  auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef(
-      op, node_name, /*ignore_unregistered_attrs=*/true);
-  if (!node_def_or.ok()) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Error converting op '" << *op << "' to NodeDef: "
-               << node_def_or.status().error_message() << "\n");
+  tensorflow::AttrValueMap attrs;
+  auto attrs_status = tensorflow::GetAttrValuesFromOperation(
+      op, node_name, /*ignore_unregistered_attrs=*/true, &attrs);
+  if (!attrs_status.ok()) {
+    LLVM_DEBUG(llvm::dbgs() << "Error creating attribute map for '" << *op
+                            << "': " << attrs_status.error_message() << "\n");
     return false;
   }
-  std::unique_ptr<tensorflow::NodeDef> node_def =
-      std::move(node_def_or).ValueOrDie();
 
   // Collect an array with input values for constant operands and input shapes
   // for all the operands.
@@ -870,8 +874,8 @@
   // Perform the shape inference using an InferenceContext with the input
   // shapes. This object is abstracting the information that the ShapeInference
   // function operates on.
-  InferenceContext c(graph_version_, *node_def, op_reg_data->op_def,
-                     input_shapes, input_tensors,
+  InferenceContext c(graph_version_, tensorflow::AttrSlice(&attrs),
+                     op_reg_data->op_def, input_shapes, input_tensors,
                      /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
   auto status = c.Run(op_reg_data->shape_inference_fn);
   if (!status.ok()) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc
new file mode 100644
index 0000000..f14efeb
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc
@@ -0,0 +1,81 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/Pass/PassOptions.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
+
+namespace mlir {
+namespace TF {
+namespace {
+
+// Deletes the op and forwards its operands to the op's results.
+template <typename TF_Op>
+class PassThroughConversion : public mlir::OpConversionPattern<TF_Op> {
+ public:
+  explicit PassThroughConversion(MLIRContext *context)
+      : mlir::OpConversionPattern<TF_Op>(context) {}
+
+  LogicalResult matchAndRewrite(
+      TF_Op op, ArrayRef<mlir::Value> operands,
+      ConversionPatternRewriter &rewriter) const override {  // NOLINT
+    // Just forward the arguments to results.
+    rewriter.replaceOp(op, operands);
+    return success();
+  }
+};
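+
+// For example (sketch): with this pattern, every use of
+//   %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
+// is rewritten to use %arg0 directly and the tf.Identity op is erased.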
+
+class TensorDeviceCopyConversionPass
+    : public PassWrapper<TensorDeviceCopyConversionPass, FunctionPass> {
+ public:
+  void runOnFunction() override {
+    mlir::OwningRewritePatternList patterns;
+    mlir::ConversionTarget target(getContext());
+
+    // TODO(tfrt-devs): when the device placer is introduced in the lowering
+    // pass, we need to check whether an Identity op and its predecessor are
+    // placed on the same device. If not, we should not fold the Identity op,
+    // since it is used for copying tensors between devices.
+    patterns.insert<PassThroughConversion<TF::IdentityOp>,
+                    PassThroughConversion<TF::IdentityNOp>>(&getContext());
+
+    if (failed(applyPartialConversion(getFunction(), target, patterns))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<mlir::FuncOp>>
+CreateTensorDeviceCopyConversionPass() {
+  return std::make_unique<TensorDeviceCopyConversionPass>();
+}
+
+static mlir::PassRegistration<TensorDeviceCopyConversionPass>
+    tensor_device_copy_pass(
+        "tf-tensor-device-copy",
+        "Handle ops that copy tensors between devices. E.g., tf.Identity.");
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc
index 2a770b2..f26887e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc
@@ -34,7 +34,7 @@
 
   void runOnFunction() override {
     Builder builder(&getContext());
-    Dialect* tf = getContext().getRegisteredDialect<TensorFlowDialect>();
+    Dialect* tf = getContext().getLoadedDialect<TensorFlowDialect>();
     getFunction().walk([&](Operation* op) {
       if (auto device_attr = op->getAttrOfType<StringAttr>("device")) {
         // We assign default device to ops with device attribute that is empty.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
index 1e4caaf..52ac87e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
@@ -20,6 +20,7 @@
 #include "mlir/IR/Identifier.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
@@ -43,6 +44,10 @@
 class GraphOptPass
     : public mlir::PassWrapper<GraphOptPass,
                                mlir::OperationPass<mlir::ModuleOp>> {
+  void getDependentDialects(mlir::DialectRegistry& registry) const override {
+    mlir::RegisterAllTensorFlowDialects(registry);
+  }
+
  public:
   explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
       : passes_(std::move(passes)) {}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index f5bdd08..13be3ad 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -78,6 +78,10 @@
 struct TPUClusterFormation
     : public TF::PerFunctionAggregateAnalysisConsumerPass<
           TPUClusterFormation, TF::ResourceAliasAnalysis> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<tf_device::TensorFlowDeviceDialect>();
+  }
+
   void runOnFunction(
       FuncOp func,
       const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
index fed4002..8a70906 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
@@ -27,6 +27,7 @@
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
@@ -34,6 +35,7 @@
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -118,7 +120,10 @@
 // computation or other ops that can be extracted, and have no operands from
 // other ops in the TPU computation that cannot be extracted.
 llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
+    const TF::SideEffectAnalysis& side_effect_analysis,
     tf_device::ClusterOp cluster) {
+  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
+      cluster.getParentOfType<FuncOp>());
   Region* cluster_region = &cluster.body();
   llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
 
@@ -127,6 +132,15 @@
     if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
     // An outside compiled op can be extracted if its operands are not from
     // other ops in the cluster that cannot be extracted.
+
+    // Check whether the closest side-effecting predecessor of this op, if
+    // any, can itself be head extracted. Because of the op ordering imposed
+    // by side effects, if it cannot, this op cannot be head extracted either.
+    auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
+    if (!predecessors.empty() &&
+        !head_outside_compiled_ops.contains(predecessors.back()))
+      continue;
+
     auto walk_result = cluster_op.walk([&](Operation* op) {
       for (Value operand : op->getOperands()) {
         Operation* operand_op = GetOpOfValue(operand);
@@ -168,11 +182,11 @@
 // Extracts and move outside compiled ops that have no dependencies in the
 // cluster to before the cluster.
 mlir::LogicalResult LiftHeadOutsideCompiledOps(
-    OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
-    tf_device::ClusterOp cluster, std::string* host_device,
-    bool* cluster_updated) {
+    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
+    const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
+    std::string* host_device, bool* cluster_updated) {
   llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
-      FindOutsideCompiledOpsAtHead(cluster);
+      FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
   if (head_outside_compiled_ops.empty()) return success();
   if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
                                                          host_device)))
@@ -191,9 +205,12 @@
 // TPU computation or other ops that can be extracted, and have no results used
 // by other ops in the TPU computation that cannot be extracted.
 void FindOutsideCompiledOpsAtTailAndClusterResults(
+    const TF::SideEffectAnalysis& side_effect_analysis,
     tf_device::ClusterOp cluster,
     llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
     llvm::SmallVectorImpl<Value>* cluster_results) {
+  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
+      cluster.getParentOfType<FuncOp>());
   Region* cluster_region = &cluster.body();
   llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
   Operation* terminator = cluster.GetBody().getTerminator();
@@ -205,6 +222,15 @@
   for (Operation& cluster_op : cluster_ops) {
     if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
 
+    // Check whether the closest side-effecting successor of this op, if any,
+    // can itself be tail extracted. Because of the op ordering imposed by
+    // side effects, if it cannot, this op cannot be tail extracted either.
+    auto successors = analysis.DirectControlSuccessors(
+        &cluster_op, [&terminator](Operation* op) { return op != terminator; });
+    if (!successors.empty() &&
+        !tail_outside_compiled_ops_set.contains(successors.front()))
+      continue;
+
     llvm::SmallVector<int, 4> results_to_forward;
     bool can_be_extracted =
         llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
@@ -293,13 +319,14 @@
 // Extracts and move outside compiled ops that do not create dependencies in the
 // cluster to after the cluster.
 mlir::LogicalResult LiftTailOutsideCompiledOps(
-    OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
-    std::string host_device, tf_device::ClusterOp* cluster,
-    bool* cluster_updated) {
+    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
+    const mlir::TF::RuntimeDevices& devices, std::string host_device,
+    tf_device::ClusterOp* cluster, bool* cluster_updated) {
   llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
   llvm::SmallVector<Value, 4> cluster_results;
-  FindOutsideCompiledOpsAtTailAndClusterResults(
-      *cluster, &tail_outside_compiled_ops, &cluster_results);
+  FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
+                                                &tail_outside_compiled_ops,
+                                                &cluster_results);
   if (tail_outside_compiled_ops.empty()) return success();
 
   if (host_device.empty())
@@ -365,6 +392,7 @@
 };
 
 void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
+  auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
   // Get runtime devices information from the closest parent module.
   auto module = getOperation();
   mlir::TF::RuntimeDevices devices;
@@ -379,10 +407,12 @@
   for (tf_device::ClusterOp cluster : clusters) {
     std::string host_device;
     bool cluster_updated = false;
-    if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster,
-                                          &host_device, &cluster_updated)) ||
-        failed(LiftTailOutsideCompiledOps(&builder, devices, host_device,
-                                          &cluster, &cluster_updated)))
+    if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
+                                          devices, cluster, &host_device,
+                                          &cluster_updated)) ||
+        failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
+                                          devices, host_device, &cluster,
+                                          &cluster_updated)))
       return signalPassFailure();
     if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
index 8adafe0..b141a7d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
@@ -314,21 +314,41 @@
   return launch_op;
 }
 
-// Extracts all externally provided operands of `cluster_ops`.
+// Extracts all externally provided operands of `host_cluster_ops`.
 llvm::SmallSetVector<Value, 4> GetExternalOperands(
-    llvm::ArrayRef<Operation*> cluster_ops) {
+    tf_device::ClusterOp tpu_cluster,
+    llvm::ArrayRef<Operation*> host_cluster_ops) {
   llvm::SmallSetVector<Value, 4> external_values;
 
-  for (Operation* op : cluster_ops) {
-    for (Value v : op->getOperands()) {
-      Operation* defining_op = v.getDefiningOp();
-      if (!defining_op) continue;
-      bool is_external = llvm::none_of(cluster_ops, [&](Operation* cluster_op) {
-        return defining_op == cluster_op;
-      });
+  for (Operation* host_cluster_op : host_cluster_ops) {
+    auto cluster_op_parent_region = host_cluster_op->getParentRegion();
+    host_cluster_op->walk([&](Operation* op) {
+      auto region = op->getParentRegion();
 
-      if (is_external) external_values.insert(v);
-    }
+      if (region == cluster_op_parent_region) {
+        // For op operands, add the operand values whose defining ops are not
+        // included in `host_cluster_ops`.
+        for (Value v : op->getOperands()) {
+          Operation* defining_op = v.getDefiningOp();
+          if (!defining_op) continue;
+          bool is_external = llvm::none_of(
+              host_cluster_ops,
+              [&](Operation* cluster_op) { return defining_op == cluster_op; });
+
+          if (is_external) external_values.insert(v);
+        }
+      } else {
+        llvm::SetVector<Value> external_captured_inputs;
+        visitUsedValuesDefinedAbove(*region, *region, [&](OpOperand* operand) {
+          Region* parent_region = operand->get().getParentRegion();
+          if (!tpu_cluster.body().isAncestor(parent_region)) return;
+
+          external_captured_inputs.insert(operand->get());
+        });
+        external_values.insert(external_captured_inputs.begin(),
+                               external_captured_inputs.end());
+      }
+    });
   }
 
   return external_values;
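+
+// For example (sketch): if a host cluster op holds a region (say a
+// tf.IfRegion) that captures a value defined in the TPU cluster but outside
+// `host_cluster_ops`, that captured value is now also reported as an
+// external operand, in addition to the directly used operands.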
@@ -494,7 +514,7 @@
         &builder, cluster_ops.back(), host_device);
 
     // Determine if there are any inputs that are provided out of cluster.
-    auto external_inputs = GetExternalOperands(cluster_ops);
+    auto external_inputs = GetExternalOperands(tpu_cluster, cluster_ops);
     auto external_outputs = GetExternalOutputs(cluster_ops);
 
     MoveOutsideCompiledOps(module, tpu_cluster, cluster.value().getFirst(),
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc
new file mode 100644
index 0000000..cccd528
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc
@@ -0,0 +1,140 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+
+namespace mlir {
+namespace TFTPU {
+
+// A pass that finds TPU clusters with write-only resource accesses and adds
+// an associated resource read, so the resource can later be fused into
+// TPUExecute.
+namespace {
+struct TPUResourceReadForWrite
+    : public PassWrapper<TPUResourceReadForWrite, OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+};
+
+// Helper struct holding a resource value and its associated type.
+struct ResourceValueAndSubtype {
+  Value resource;
+  Type subtype;
+};
+
+// Finds the resource handle and subtype for a result if the result is written
+// to a resource variable.
+ResourceValueAndSubtype GetResourceWriteResult(
+    tf_device::ClusterFuncOp cluster_func, Value result) {
+  ResourceValueAndSubtype resource;
+  if (!result.hasOneUse()) return resource;
+  Operation* result_user = *result.getUsers().begin();
+  auto assign_var = dyn_cast<TF::AssignVariableOp>(result_user);
+  if (!assign_var) return resource;
+
+  auto handle = assign_var.resource();
+  // Skip result if cluster writes to the same variable via multiple results.
+  for (Operation* handle_user : handle.getUsers()) {
+    if (handle_user == assign_var) continue;
+    auto assign_var_user = dyn_cast<TF::AssignVariableOp>(handle_user);
+    if (!assign_var_user) continue;
+    if (assign_var_user.value().getDefiningOp() == cluster_func)
+      return resource;
+  }
+
+  resource.resource = assign_var.resource();
+  resource.subtype = assign_var.value().getType();
+  return resource;
+}
+
+// Checks if resource is read by TPU cluster.
+bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
+                                Value resource) {
+  for (Operation* resource_user : resource.getUsers())
+    if (auto read = dyn_cast<TF::ReadVariableOp>(resource_user))
+      for (Operation* read_user : read.value().getUsers())
+        if (read_user == cluster_func) return true;
+
+  return false;
+}
+
+void TPUResourceReadForWrite::runOnOperation() {
+  SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
+  getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
+    cluster_funcs.push_back(cluster_func);
+  });
+
+  OpBuilder builder(&getContext());
+  // For resources that the TPU cluster writes to but never reads from, add a
+  // resource read before the cluster.
+  for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) {
+    builder.setInsertionPoint(cluster_func);
+
+    SmallVector<Value, 4> read_operands;
+    for (Value result : cluster_func.getResults()) {
+      // TODO(lyandy): Update pass to use resource alias analysis.
+      auto resource_and_type = GetResourceWriteResult(cluster_func, result);
+      if (!resource_and_type.resource) continue;
+      if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource))
+        continue;
+      auto new_read = builder.create<TF::ReadVariableOp>(
+          resource_and_type.resource.getLoc(), resource_and_type.subtype,
+          resource_and_type.resource);
+      read_operands.push_back(new_read.value());
+    }
+
+    if (read_operands.empty()) continue;
+
+    // Update caller and function types with new read operands.
+    auto operands = llvm::to_vector<4>(cluster_func.getOperands());
+    operands.append(read_operands.begin(), read_operands.end());
+
+    auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
+        cluster_func.getLoc(), cluster_func.getResultTypes(), operands,
+        cluster_func.getAttrs());
+    cluster_func.replaceAllUsesWith(new_cluster_func);
+    FuncOp func = cluster_func.getFunc();
+    Block& block = func.front();
+    for (Value read_operand : read_operands)
+      block.addArgument(read_operand.getType());
+
+    func.setType(FunctionType::get(block.getArgumentTypes(),
+                                   func.getCallableResults(), &getContext()));
+    cluster_func.erase();
+  }
+}
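+
+// For example (sketch): for a cluster that writes variable %v but never
+// reads it,
+//   before: %w = "tf_device.cluster_func"(%x) { ... }
+//           "tf.AssignVariableOp"(%v, %w)
+//   after:  %r = "tf.ReadVariableOp"(%v)
+//           %w = "tf_device.cluster_func"(%x, %r) { ... }
+//           "tf.AssignVariableOp"(%v, %w)
+// and the callee function gains a matching block argument for %r.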
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
+  return std::make_unique<TPUResourceReadForWrite>();
+}
+
+static PassRegistration<TPUResourceReadForWrite> pass(
+    "tf-tpu-resource-read-for-write",
+    "Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes "
+    "with no reads");
+
+}  // namespace TFTPU
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index ca77fea..3fb0075 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -30,6 +30,7 @@
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/OperationSupport.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -157,7 +158,9 @@
   // Serialize module and return.
   {
     llvm::raw_string_ostream os(*serialized_func_module);
-    module_for_func.get().print(os);
+    OpPrintingFlags print_flags;
+    print_flags.enableDebugInfo();
+    module_for_func.get().print(os, print_flags);
   }
   return success();
 }
@@ -409,12 +412,15 @@
   std::string txt_module;
   if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;
 
-  auto result_type =
+  auto compilation_status_type =
       RankedTensorType::get({}, builder->getType<TF::StringType>());
+  auto program_type =
+      RankedTensorType::get({2}, builder->getType<TF::StringType>());
 
   auto compile_op = builder->create<TF::_TPUCompileMlirOp>(
-      cluster_func.getLoc(), /*compilation_status=*/result_type, /*program=*/
-      llvm::SmallVector<Type, 8>(num_cores_per_replica, result_type),
+      cluster_func.getLoc(),
+      /*compilation_status=*/compilation_status_type, /*program=*/
+      llvm::SmallVector<Type, 8>(num_cores_per_replica, program_type),
       compile_op_operands, txt_module, txt_metadata);
 
   return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op,
@@ -598,9 +604,9 @@
 // func @main(%arg0: tensor<i1>) {
 //   %0 = "tf.Shape"(%arg0) : (tensor<i1>) -> tensor<?xi32>
 //   %1:2 = "tf._TPUCompileMlir"(%0) {device = "/CPU:0"} :
-//            (tensor<?xi32>) -> (tensor<!tf.string>, tensor<!tf.string>)
+//            (tensor<?xi32>) -> (tensor<!tf.string>, tensor<2x!tf.string>)
 //   %2 = "tf.TPUExecute"(%arg0, %1#0) {device = "/TPU:0"} :
-//            (tensor<i1>, tensor<!tf.string>) -> tensor<i1>
+//            (tensor<i1>, tensor<2x!tf.string>) -> tensor<i1>
 //   return
 // }
 //
@@ -624,9 +630,9 @@
 //                              {n = 2 : i32, devices = ["/TPU:0", "/TPU:1"]} {
 //     %1 = "tf.Shape"(%ri) : (tensor<i1>) -> tensor<?xi32>
 //     %2:2 = "tf._TPUCompileMlir"(%1) {device = "/CPU:0"} :
-//              (tensor<?xi32>) -> (tensor<!tf.string>, tensor<!tf.string>)
+//              (tensor<?xi32>) -> (tensor<!tf.string>, tensor<2x!tf.string>)
 //     %3 = "tf.TPUExecute"(%ri, %2#0) :
-//            (tensor<i1>, tensor<!tf.string>) -> tensor<i1>
+//            (tensor<i1>, tensor<2x!tf.string>) -> tensor<i1>
 //     tf_device.return %3 : tensor<i1>
 //   }
 //   return
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
index 0a69987..b65f07c 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
@@ -43,6 +43,10 @@
 
 class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
                            BreakUpIslands, TF::SideEffectAnalysis> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<tf_executor::TensorFlowExecutorDialect>();
+  }
+
  public:
   void runOnFunction(FuncOp func,
                      const TF::SideEffectAnalysis::Info& side_effect_analysis);
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index 571d5e3..0445dbb 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -49,6 +49,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
@@ -80,46 +81,14 @@
 constexpr char kDeviceAttr[] = "tf.device";
 constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
 
-bool IsLegalChar(char c, bool first_char) {
-  if (isalpha(c)) return true;
-  if (isdigit(c)) return true;
-  if (c == '.') return true;
-  if (c == '_') return true;
-
-  // First character of a node name can only be a letter, digit, dot or
-  // underscore.
-  if (first_char) return false;
-
-  if (c == '/') return true;
-  if (c == '-') return true;
-
-  return false;
-}
-
-// Convert characters in name that are considered illegal in TensorFlow Node
-// name to '.'.
-std::string LegalizeNodeName(llvm::StringRef name) {
-  assert(!name.empty() && "expected non-empty name");
-
-  std::string legalized_name;
-  bool first = true;
-  for (auto c : name) {
-    if (IsLegalChar(c, first)) {
-      legalized_name += c;
-    } else {
-      legalized_name += '.';
-    }
-    first = false;
-  }
-
-  return legalized_name;
-}
-
 // OpOrArgLocNameMapper that legalizes the returned name.
 class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
  private:
   std::string GetName(OpOrVal op_or_val) override {
-    return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val));
+    std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
+    assert(!name.empty() && "expected non-empty name");
+    mlir::LegalizeNodeName(name);
+    return name;
   }
 };
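+
+// For example (sketch, assuming mlir::LegalizeNodeName keeps the replacement
+// rules of the removed helper above): a name like "while/body:0" becomes
+// "while/body.0", since ':' is not a legal TensorFlow node name character.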
 
@@ -523,13 +492,14 @@
       if (index >= num_data_results) break;
       // TODO(jpienaar): If there is a result index specified, ensure only one
       // and that it matches the result index of the op.
-      std::string orig_name(output_names[index]);
-      auto tensor_id = ParseTensorName(orig_name);
-      auto name = LegalizeNodeName(
-          llvm::StringRef(tensor_id.node().data(), tensor_id.node().size()));
+      std::string name(output_names[index]);
+      auto tensor_id = ParseTensorName(name);
+      std::string tensor_id_node(tensor_id.node());
+      assert(!tensor_id_node.empty() && "expected non-empty name");
+      mlir::LegalizeNodeName(tensor_id_node);
 
       // Ensure name does not get reused.
-      (void)exporter.op_to_name_.GetUniqueName(name);
+      (void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
     }
   }
 
@@ -537,8 +507,9 @@
     TF_RET_CHECK(input_names.size() == block.getNumArguments());
     for (const auto& it : llvm::enumerate(function.getArguments())) {
       // TODO(lyandy): Update when changing feed/fetch import.
-      std::string orig_name(input_names[it.index()]);
-      std::string name = LegalizeNodeName(orig_name);
+      std::string name(input_names[it.index()]);
+      assert(!name.empty() && "expected non-empty name");
+      mlir::LegalizeNodeName(name);
       auto tensor_id = ParseTensorName(name);
       TF_RET_CHECK(tensor_id.index() == 0)
           << "input port designation not supported";
@@ -726,7 +697,7 @@
       mlir::Identifier::get("main", module.getContext());
   absl::optional<mlir::FuncOp> entry_func;
   FunctionDefLibrary flib;
-  auto tf_dialect = module.getContext()->getRegisteredDialect("tf");
+  auto tf_dialect = module.getContext()->getLoadedDialect("tf");
   for (auto function : module.getOps<mlir::FuncOp>()) {
     if (function.isExternal())
       return errors::FailedPrecondition("External functions not supported");
@@ -799,7 +770,7 @@
 stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef(
     mlir::FuncOp func, const GraphExportConfig& configs,
     FunctionDef* function_def) {
-  Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf");
+  Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf");
   FunctionDefLibrary flib;
   TF_RETURN_IF_ERROR(
       Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib));
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
index 3ca06e5..727831a 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
@@ -25,8 +25,8 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
-#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -34,7 +34,6 @@
 namespace tensorflow {
 
 namespace {
-using stream_executor::port::StatusOr;
 
 // Sets type list attribute with the given `name` to the given `types`. If the
 // attribute already exists with a different value, returns an error.
@@ -90,7 +89,7 @@
 // definitions and isn't a header file.
 #include "tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator.inc"
 
-// Collect all the unregistered attributes for an TF dialect operation.
+// Collects all the unregistered attributes for a TF dialect operation.
 // Attributes "name" and "device" are not included because they are not part
 // of an TF op attributes.
 Status GetUnregisteredAttrs(
@@ -123,17 +122,10 @@
   return Status::OK();
 }
 
-}  // namespace
-
-StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
-    mlir::Operation* inst, llvm::StringRef name,
-    bool ignore_unregistered_attrs) {
-  // Use auto generated function to populate derived attribute.
-  //
-  // Note: This only populates derived attributes for TensorFlow ops that are
-  // generated using the TableGen. Manually defined ops should have all the
-  // attributes present as native MLIR op attributes.
-
+// Collects all attribute names to ignore in an MLIR operation when exporting to
+// a TensorFlow NodeDef.
+StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
+    mlir::Operation* inst, bool ignore_unregistered_attrs) {
   // The elements are owned by the MLIRContext.
   absl::flat_hash_set<absl::string_view> attrs_to_ignore;
   if (inst->isRegistered()) {
@@ -162,15 +154,25 @@
     attrs_to_ignore.insert(attr_name.data());
   }
 
-  TF_ASSIGN_OR_RETURN(auto node_def,
-                      GetOperationNodeDef(attrs_to_ignore, inst, name));
+  return attrs_to_ignore;
+}
+
+// Populates all derived attributes of an MLIR operation into a proto
+// map<string, AttrValue>.
+Status PopulateDerivedAttributes(mlir::Operation* inst,
+                                 bool ignore_unregistered_attrs,
+                                 AttrValueMap* attributes) {
+  // Use auto generated function to populate derived attribute.
+  //
+  // Note: This only populates derived attributes for TensorFlow ops that are
+  // generated using the TableGen. Manually defined ops should have all the
+  // attributes present as native MLIR op attributes.
 
   // If the operation is not registered, we won't be able to infer any attribute
   if (inst->isRegistered()) {
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(
-        PopulateDerivedAttrs(inst, node_def->mutable_attr()),
-        "When populating derived attrs for ",
-        inst->getName().getStringRef().str());
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(PopulateDerivedAttrs(inst, attributes),
+                                    "When populating derived attrs for ",
+                                    inst->getName().getStringRef().str());
   }
 
   // Here we only add the shapes for the leading values with ShapedType,
@@ -185,10 +187,38 @@
       mlir::TF::ResultShapeRange output_shapes = {
           mlir::TF::ResultShapeIterator(begin),
           mlir::TF::ResultShapeIterator(end)};
-      TF_RETURN_IF_ERROR(SetShapeAttribute("_output_shapes", output_shapes,
-                                           node_def->mutable_attr()));
+      TF_RETURN_IF_ERROR(
+          SetShapeAttribute("_output_shapes", output_shapes, attributes));
     }
   }
+
+  return Status::OK();
+}
+
+}  // namespace
+
+Status GetAttrValuesFromOperation(mlir::Operation* inst, llvm::StringRef name,
+                                  bool ignore_unregistered_attrs,
+                                  AttrValueMap* attributes) {
+  TF_ASSIGN_OR_RETURN(auto attrs_to_ignore,
+                      GetAttributesToIgnore(inst, ignore_unregistered_attrs));
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(
+      ConvertAttributes(inst->getAttrs(), attrs_to_ignore, attributes),
+      "while converting attributes for node: ", name.str());
+  TF_RETURN_IF_ERROR(
+      PopulateDerivedAttributes(inst, ignore_unregistered_attrs, attributes));
+  return Status::OK();
+}
+
+StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
+    mlir::Operation* inst, llvm::StringRef name,
+    bool ignore_unregistered_attrs) {
+  TF_ASSIGN_OR_RETURN(auto attrs_to_ignore,
+                      GetAttributesToIgnore(inst, ignore_unregistered_attrs));
+  TF_ASSIGN_OR_RETURN(auto node_def,
+                      GetOperationNodeDef(attrs_to_ignore, inst, name));
+  TF_RETURN_IF_ERROR(PopulateDerivedAttributes(inst, ignore_unregistered_attrs,
+                                               node_def->mutable_attr()));
   return node_def;
 }
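+
+// For example (sketch): callers that only need the attributes, such as the
+// shape inference pass above, can now populate the map directly:
+//   tensorflow::AttrValueMap attrs;
+//   TF_RETURN_IF_ERROR(GetAttrValuesFromOperation(
+//       op, name, /*ignore_unregistered_attrs=*/true, &attrs));
+// and wrap it in tensorflow::AttrSlice(&attrs) for InferenceContext, skipping
+// the NodeDef round trip.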
 
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h
index a19ad1f..bd26017 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h
@@ -18,12 +18,22 @@
 
 #include "llvm/ADT/StringRef.h"
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace tensorflow {
 
-// Converts an MLIR operation to TensorFlow NodeDef with given node name. This
+// Extracts the attributes of an MLIR operation and populates the converted
+// attributes into a proto map<string, AttrValue>.
+Status GetAttrValuesFromOperation(mlir::Operation* inst, llvm::StringRef name,
+                                  bool ignore_unregistered_attrs,
+                                  AttrValueMap* attributes);
+
+// Converts an MLIR operation to a TensorFlow NodeDef with the given node name. This
 // name should be unique to the graph it is being inserted to. If the
 // `ignore_unregistered_attrs` argument is set to true, the attributes which are
 // not in the op registry will be ignored. If the `ignore_unregistered_attrs`
@@ -31,9 +41,9 @@
 // ShapedType for the leading values with ShapedType in the results of the
 // nodes. Set it to true if the returned NodeDef will be executed by the linked
 // TF Eager runtime.
-stream_executor::port::StatusOr<std::unique_ptr<NodeDef>>
-ConvertTFDialectOpToNodeDef(mlir::Operation* inst, llvm::StringRef name,
-                            bool ignore_unregistered_attrs);
+StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
+    mlir::Operation* inst, llvm::StringRef name,
+    bool ignore_unregistered_attrs);
 
 }  // namespace tensorflow
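
The two exported entry points above differ only in where the converted attributes land: GetAttrValuesFromOperation fills a caller-supplied AttrValueMap, while ConvertTFDialectOpToNodeDef allocates a complete NodeDef. A minimal sketch of calling them; the wrapper function and node name are hypothetical, and `inst` is assumed to be a registered TF-dialect op:

    #include <memory>

    #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"

    tensorflow::Status ExportExample(mlir::Operation* inst) {
      // Variant 1: only the attributes, into a caller-owned map.
      tensorflow::AttrValueMap attrs;
      tensorflow::Status status = tensorflow::GetAttrValuesFromOperation(
          inst, "my_node", /*ignore_unregistered_attrs=*/true, &attrs);
      if (!status.ok()) return status;

      // Variant 2: a full NodeDef; the name must be unique in the target graph.
      auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef(
          inst, "my_node", /*ignore_unregistered_attrs=*/true);
      if (!node_def_or.ok()) return node_def_or.status();
      std::unique_ptr<tensorflow::NodeDef> node_def =
          node_def_or.ConsumeValueOrDie();
      (void)node_def;  // A real caller would insert this into a GraphDef.
      return tensorflow::Status::OK();
    }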
 
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 51f6374..b78f311 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -64,6 +64,7 @@
 #include "tensorflow/cc/saved_model/loader_util.h"
 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -141,6 +142,13 @@
   return false;
 }
 
+void LoadImporterDialects(mlir::MLIRContext& context) {
+  // Load the dialects involved in the conversion.
+  mlir::DialectRegistry registry;
+  mlir::RegisterAllTensorFlowDialects(registry);
+  registry.loadAll(&context);
+}
+
 // This class is used to generate new MLIR function name strings that are both
 // unique in the TF function library `flib_` and unique among the name strings
 // generated by the class object during its lifetime.
@@ -2136,6 +2144,7 @@
     mlir::MLIRContext* context, const Graph& graph,
     const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
     const GraphImportConfig& specs, llvm::StringRef func_name) {
+  LoadImporterDialects(*context);
   mlir::OwningModuleRef module =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
   std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
@@ -3192,6 +3201,7 @@
 StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
     SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
     mlir::MLIRContext* context, bool add_default_attributes) {
+  LoadImporterDialects(*context);
   GraphDebugInfo dummy_debug_info;
   const GraphDebugInfo& debug_info =
       saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info;
@@ -3271,6 +3281,7 @@
   static StatusOr<mlir::OwningModuleRef> Convert(
       const SavedModelBundle& bundle, absl::Span<std::string> exported_names,
       mlir::MLIRContext* context, bool upgrade_legacy) {
+    LoadImporterDialects(*context);
     SavedModelSignatureDefImporter importer(bundle, exported_names, context);
     TF_RETURN_IF_ERROR(importer.InitializeGraph(upgrade_legacy));
     return importer.ConvertSignatures();
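
Each importer entry point now calls LoadImporterDialects on its context before building IR, replacing the old process-wide static registration. A sketch of the same pattern in isolation, assuming only the dialect registration header; the helper name is hypothetical:

    #include "mlir/IR/MLIRContext.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"

    void PrepareContext(mlir::MLIRContext& context) {
      // Collect the TensorFlow dialects in a registry, then load them eagerly
      // so that ops from those dialects can be created in this context.
      mlir::DialectRegistry registry;
      mlir::RegisterAllTensorFlowDialects(registry);
      registry.loadAll(&context);
    }

Loading per context, rather than globally, means two contexts in the same process can see different dialect sets.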
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
index b646e14..f63cb09 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc
@@ -23,6 +23,7 @@
 #include "llvm/Support/MemoryBuffer.h"
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/Translation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
@@ -86,6 +87,9 @@
 }
 
 static TranslateFromMLIRRegistration mlir_to_graphdef_translate(
-    "mlir-to-graphdef", MlirToGraphdefTranslateFunction);
+    "mlir-to-graphdef", MlirToGraphdefTranslateFunction,
+    [](DialectRegistry& registry) {
+      mlir::RegisterAllTensorFlowDialects(registry);
+    });
 
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc
index 5236bde..22e6559 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc
@@ -20,6 +20,7 @@
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/Translation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
 
 namespace mlir {
@@ -67,6 +68,7 @@
 // Test only translation to convert a simple MLIR module with a single TF
 // dialect op to NodeDef.
 static TranslateFromMLIRRegistration translate_from_mlir_registration(
-    "test-only-mlir-to-tf-nodedef", MlirToTfNodeDef);
+    "test-only-mlir-to-tf-nodedef", MlirToTfNodeDef,
+    mlir::RegisterAllTensorFlowDialects);
 
 }  // namespace mlir
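
Both translation registrations now pass a third argument: a callback that populates the DialectRegistry with the dialects the translation expects, which mlir-translate invokes before parsing the input. A hedged sketch of registering a custom translation the same way (the translation name and function are hypothetical):

    #include "llvm/Support/raw_ostream.h"
    #include "mlir/IR/Module.h"  // from @llvm-project
    #include "mlir/Translation.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"

    static mlir::LogicalResult MyTranslation(mlir::ModuleOp module,
                                             llvm::raw_ostream& output) {
      module.print(output);  // Placeholder body: just echo the module.
      return mlir::success();
    }

    static mlir::TranslateFromMLIRRegistration my_translation_registration(
        "my-translation", MyTranslation,
        [](mlir::DialectRegistry& registry) {
          mlir::RegisterAllTensorFlowDialects(registry);
        });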
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 99a5e32..0dbda2e 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -36,7 +36,9 @@
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@@ -276,16 +278,9 @@
   return Status::OK();
 }
 
-static void RegisterDialects() {
-  static bool init_once = []() {
-    mlir::registerDialect<mlir::StandardOpsDialect>();
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    mlir::registerDialect<mlir::shape::ShapeDialect>();
-    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
-    mlir::registerDialect<mlir::mhlo::MhloDialect>();
-    return true;
-  }();
-  (void)init_once;
+static void RegisterDialects(mlir::DialectRegistry& registry) {
+  mlir::RegisterAllTensorFlowDialects(registry);
+  mlir::mhlo::registerAllMhloDialects(registry);
 }
 
 }  //  namespace
@@ -418,8 +413,8 @@
     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
-  RegisterDialects();
   mlir::MLIRContext mlir_context;
+  RegisterDialects(mlir_context.getDialectRegistry());
   mlir::OwningModuleRef mlir_module;
 
   TF_RETURN_IF_ERROR(
@@ -506,9 +501,8 @@
     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
-  RegisterDialects();
-
   mlir::MLIRContext context;
+  RegisterDialects(context.getDialectRegistry());
   GraphImportConfig config;
   config.graph_as_function = true;
   // Disable shape inference during import as some TensorFlow op fails during
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
index 270ef2d..05e1f05 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
@@ -161,7 +161,7 @@
     default:
       // TODO(shpeisman): restructure code to reuse dialect pointer across
       // calls.
-      auto* dialect = builder->getContext()->getRegisteredDialect("tf");
+      auto* dialect = builder->getContext()->getLoadedDialect("tf");
       return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
   }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
index bf96e3d..6266a5e 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
@@ -20,6 +20,7 @@
 
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -33,16 +34,13 @@
 namespace tensorflow {
 namespace {
 
-static void RegisterDialects() {
-  static bool init_once = []() {
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    return true;
-  }();
-  (void)init_once;
+static void RegisterDialects(mlir::MLIRContext &context) {
+  context.loadDialect<mlir::TF::TensorFlowDialect>();
 }
 
 TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
   mlir::MLIRContext context;
+  RegisterDialects(context);
   mlir::Builder b(&context);
 
   PartialTensorShape output_shape =
@@ -52,6 +50,7 @@
 
 TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
   mlir::MLIRContext context;
+  RegisterDialects(context);
   mlir::Builder b(&context);
 
   PartialTensorShape output_shape = ConvertTypeToTensorShape(
@@ -61,6 +60,7 @@
 
 TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
   mlir::MLIRContext context;
+  RegisterDialects(context);
   mlir::Builder b(&context);
 
   PartialTensorShape output_shape = ConvertTypeToTensorShape(
@@ -77,8 +77,8 @@
 }
 
 TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) {
-  RegisterDialects();
   mlir::MLIRContext context;
+  RegisterDialects(context);
   mlir::Builder b(&context);
 
   // Create the sample tensor to convert.
@@ -123,9 +123,8 @@
 };
 
 TEST_F(ConvertTensorTest, Simple) {
-  RegisterDialects();
-
   mlir::MLIRContext context;
+  RegisterDialects(context);
   ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
       {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
   ASSERT_NO_FATAL_FAILURE(
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
index b23fbe7..19eb5b2 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
@@ -625,8 +625,8 @@
 }
 
 TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
-  mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::MLIRContext context;
+  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::OwningModuleRef module_ref =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   mlir::OpBuilder builder(module_ref->getBodyRegion());
@@ -641,8 +641,8 @@
 }
 
 TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
-  mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::MLIRContext context;
+  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::OwningModuleRef module_ref =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   mlir::OpBuilder builder(module_ref->getBodyRegion());
@@ -662,8 +662,8 @@
 }
 
 TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
-  mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::MLIRContext context;
+  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::OwningModuleRef module_ref =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   mlir::OpBuilder builder(module_ref->getBodyRegion());
@@ -682,8 +682,8 @@
 }
 
 TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
-  mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::MLIRContext context;
+  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::OwningModuleRef module_ref =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   mlir::OpBuilder builder(module_ref->getBodyRegion());
@@ -702,8 +702,8 @@
 }
 
 TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
-  mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::MLIRContext context;
+  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::OwningModuleRef module_ref =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   mlir::OpBuilder builder(module_ref->getBodyRegion());
@@ -725,8 +725,8 @@
 }
 
 TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
-  mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::MLIRContext context;
+  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::OwningModuleRef module_ref =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   mlir::OpBuilder builder(module_ref->getBodyRegion());
@@ -750,8 +750,8 @@
 }
 
 TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
-  mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::MLIRContext context;
+  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::OwningModuleRef module_ref =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   mlir::OpBuilder builder(module_ref->getBodyRegion());
@@ -777,8 +777,8 @@
 }
 
 TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
-  mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::MLIRContext context;
+  context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
   mlir::OwningModuleRef module_ref =
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   mlir::OpBuilder builder(module_ref->getBodyRegion());
diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
index 1416ac0..e48b14a 100644
--- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
@@ -13,81 +13,29 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include "mlir/IR/AsmState.h"  // from @llvm-project
-#include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Pass/PassManager.h"  // from @llvm-project
-#include "mlir/Support/FileUtilities.h"  // from @llvm-project
+#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
+#include "mlir/InitAllDialects.h"  // from @llvm-project
+#include "mlir/InitAllPasses.h"  // from @llvm-project
 #include "mlir/Support/MlirOptMain.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
 #include "tensorflow/compiler/mlir/init_mlir.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
 #include "tensorflow/core/platform/init_main.h"
-#include "tensorflow/core/platform/logging.h"
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<std::string> input_filename(llvm::cl::Positional,
-                                                 llvm::cl::desc("<input file>"),
-                                                 llvm::cl::init("-"));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<std::string> output_filename(
-    "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
-    llvm::cl::init("-"));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> split_input_file(
-    "split-input-file",
-    llvm::cl::desc("Split the input file into pieces and process each "
-                   "chunk independently"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> verify_diagnostics(
-    "verify-diagnostics",
-    llvm::cl::desc("Check that emitted diagnostics match "
-                   "expected-* lines on the corresponding line"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> verify_passes(
-    "verify-each",
-    llvm::cl::desc("Run the verifier after each transformation pass"),
-    llvm::cl::init(true));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> allowUnregisteredDialects(
-    "allow-unregistered-dialect",
-    llvm::cl::desc("Allow operation with no registered dialects"),
-    llvm::cl::init(false));
 
 int main(int argc, char **argv) {
   tensorflow::InitMlir y(&argc, &argv);
 
-  // Register various MLIR command line options.
-  mlir::registerAsmPrinterCLOptions();
-  mlir::registerMLIRContextCLOptions();
-  mlir::registerPassManagerCLOptions();
+  mlir::registerAllPasses();
 
-  // Parse pass names in main to ensure static initialization completed.
-  mlir::PassPipelineCLParser pass_pipeline("", "Compiler passes to run");
-
-  llvm::cl::ParseCommandLineOptions(argc, argv,
-                                    "TF MLIR modular optimizer driver\n");
-
-  // Set up the input file.
-  std::string error_message;
-  auto file = mlir::openInputFile(input_filename, &error_message);
-  QCHECK(file) << error_message;
-
-  auto output = mlir::openOutputFile(output_filename, &error_message);
-  QCHECK(output) << error_message;
-
-  if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_pipeline,
-                               split_input_file, verify_diagnostics,
-                               verify_passes, allowUnregisteredDialects)))
-    return 1;
-  output->keep();
-  return 0;
+  mlir::DialectRegistry registry;
+  mlir::registerAllDialects(registry);
+  mlir::RegisterAllTensorFlowDialects(registry);
+  mlir::mhlo::registerAllMhloDialects(registry);
+  registry.insert<mlir::shape::ShapeDialect>();
+  registry.insert<mlir::TFL::TensorFlowLiteDialect>();
+  registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
+  return failed(
+      mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry));
 }
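
After this rewrite, the option parsing, file handling, and pass-pipeline plumbing that used to live in main are delegated to the registry-taking mlir::MlirOptMain overload, so a new opt-style tool reduces to assembling a DialectRegistry. A minimal sketch under that assumption (tool title hypothetical):

    #include "mlir/IR/Dialect.h"  // from @llvm-project
    #include "mlir/InitAllDialects.h"  // from @llvm-project
    #include "mlir/InitAllPasses.h"  // from @llvm-project
    #include "mlir/Support/MlirOptMain.h"  // from @llvm-project

    int main(int argc, char **argv) {
      mlir::registerAllPasses();

      // Declare the dialects the tool may parse; they are loaded on demand.
      mlir::DialectRegistry registry;
      mlir::registerAllDialects(registry);
      return failed(mlir::MlirOptMain(argc, argv, "Example pass driver\n",
                                      registry));
    }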
diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
index caac8ea..3ea92a7 100644
--- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc
@@ -111,7 +111,6 @@
 
   if (import_saved_model_object_graph) {
     mlir::MLIRContext context;
-
     auto module_or = tensorflow::SavedModelObjectGraphToMlirImport(
         input_filename, tags, exported_names, &context);
     if (!module_or.status().ok()) return 1;
@@ -119,7 +118,6 @@
     module_or.ConsumeValueOrDie()->print(output->os());
   } else if (import_saved_model_signature_defs) {
     mlir::MLIRContext context;
-
     auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport(
         input_filename, tags, exported_names, &context, upgrade_legacy);
     if (!module_or.status().ok()) return 1;
diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD
index 7d3091f..4db9600 100644
--- a/tensorflow/compiler/mlir/tfjs/BUILD
+++ b/tensorflow/compiler/mlir/tfjs/BUILD
@@ -68,17 +68,6 @@
     alwayslink = 1,
 )
 
-cc_library(
-    name = "tensorflow_js_dialect_registration",
-    srcs = [
-        "ir/dialect_registration.cc",
-    ],
-    deps = [
-        ":tensorflow_js",
-    ],
-    alwayslink = 1,
-)
-
 gentbl(
     name = "tfjs_optimize_inc_gen",
     tbl_outs = [
@@ -107,7 +96,6 @@
     ],
     deps = [
         ":tensorflow_js",
-        ":tensorflow_js_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Analysis",
@@ -149,7 +137,6 @@
     ],
     deps = [
         ":tensorflow_js",
-        ":tensorflow_js_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
         "//tensorflow/compiler/mlir/tensorflow:export_utils",
@@ -236,3 +223,20 @@
         "@llvm-project//mlir:Support",
     ],
 )
+
+tf_cc_binary(
+    name = "tfjs-opt",
+    srcs = [
+        "tfjs_opt.cc",
+    ],
+    deps = [
+        ":tensorflow_js",
+        ":tensorflow_js_passes",
+        "//tensorflow/compiler/mlir:init_mlir",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
+        "@llvm-project//mlir:MlirOptLib",
+        "@llvm-project//mlir:StandardOps",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tfjs/tests/BUILD b/tensorflow/compiler/mlir/tfjs/tests/BUILD
index a4ebc99..5789480 100644
--- a/tensorflow/compiler/mlir/tfjs/tests/BUILD
+++ b/tensorflow/compiler/mlir/tfjs/tests/BUILD
@@ -3,8 +3,11 @@
 package(licenses = ["notice"])
 
 glob_lit_tests(
-    data = [":test_utilities"],
-    driver = "@llvm-project//mlir:run_lit.sh",
+    data = [
+        ":test_utilities",
+        "@llvm-project//mlir:run_lit.sh",
+    ],
+    driver = "//tensorflow/compiler/mlir:run_lit.sh",
     test_file_exts = ["mlir"],
 )
 
@@ -13,7 +16,7 @@
     name = "test_utilities",
     testonly = True,
     data = [
-        "//tensorflow/compiler/mlir:tf-opt",
+        "//tensorflow/compiler/mlir/tfjs:tfjs-opt",
         "@llvm-project//llvm:FileCheck",
         "@llvm-project//llvm:not",
     ],
diff --git a/tensorflow/compiler/mlir/tfjs/tests/ops.mlir b/tensorflow/compiler/mlir/tfjs/tests/ops.mlir
index 0b72101..602f346 100644
--- a/tensorflow/compiler/mlir/tfjs/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/tfjs/tests/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s
+// RUN: tfjs-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s
 
 // -----
 
diff --git a/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir b/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir
index 5f046dc..f4464dd 100644
--- a/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir
@@ -1,5 +1,5 @@
 // Run optimize pass only and check the results.
-// RUN: tf-opt %s -tfjs-optimize | FileCheck %s
+// RUN: tfjs-opt %s -tfjs-optimize | FileCheck %s
 
 // CHECK-LABEL: prelu_fusion
 func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
diff --git a/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc b/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc
new file mode 100644
index 0000000..c601312
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/InitAllPasses.h"  // from @llvm-project
+#include "mlir/Support/MlirOptMain.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/init_mlir.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
+
+int main(int argc, char **argv) {
+  tensorflow::InitMlir y(&argc, &argv);
+
+  mlir::registerAllPasses();
+
+  mlir::DialectRegistry registry;
+  registry.insert<mlir::StandardOpsDialect>();
+  registry.insert<mlir::TF::TensorFlowDialect>();
+  registry.insert<mlir::tfjs::TFJSDialect>();
+  return failed(mlir::MlirOptMain(argc, argv, "TF JS pass driver\n", registry));
+}
diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc
index c03a684..a3678f7 100644
--- a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc
@@ -37,6 +37,9 @@
 
 // Optimize TFJS operations in functions.
 struct Optimize : public PassWrapper<Optimize, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<TFJSDialect>();
+  }
   void runOnFunction() override;
 };
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
index e01c059..3c88318 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD
@@ -21,7 +21,7 @@
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LLVMDialect",
@@ -65,8 +65,10 @@
     srcs = ["tools/kernel-gen-opt/kernel-gen-opt.cc"],
     visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__pkg__"],
     deps = [
+        "//tensorflow/compiler/mlir/hlo:all_passes",
         "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
-        "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_dialect_registration",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
         "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
index 82b0e61..3b6af7f 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
@@ -46,6 +46,7 @@
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
@@ -246,21 +247,14 @@
   return Status::OK();
 }
 
-void RegisterDialects() {
-  static bool init_once = []() {
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    return true;
-  }();
-  (void)init_once;
-}
 }  // namespace
 
 StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
     llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
     llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
     llvm::ArrayRef<uint32_t> unroll_factors) {
-  RegisterDialects();
   mlir::MLIRContext context;
+  mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
   mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
 
   TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD
index 3a28d48..29939f2 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD
@@ -35,13 +35,3 @@
         "@llvm-project//mlir:SideEffects",
     ],
 )
-
-cc_library(
-    name = "tf_framework_dialect_registration",
-    srcs = ["dialect_registration.cc"],
-    deps = [
-        ":tf_framework_ops",
-        "@llvm-project//mlir:IR",
-    ],
-    alwayslink = 1,
-)
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/dialect_registration.cc
deleted file mode 100644
index a2e5955..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/dialect_registration.cc
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
-
-// Static initialization for TF Framework dialect registration.
-static mlir::DialectRegistration<
-    mlir::kernel_gen::tf_framework::TFFrameworkDialect>
-    tf_framework_ops;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc
index c1af356..85f1faf 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc
@@ -13,110 +13,27 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/ToolOutputFile.h"
-#include "mlir/IR/AsmState.h"  // from @llvm-project
-#include "mlir/IR/Dialect.h"  // from @llvm-project
-#include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "mlir/InitAllPasses.h"  // from @llvm-project
-#include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Pass/PassManager.h"  // from @llvm-project
-#include "mlir/Support/FileUtilities.h"  // from @llvm-project
 #include "mlir/Support/MlirOptMain.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
 
-// NOLINTNEXTLINE
-static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
-                                                llvm::cl::desc("<input file>"),
-                                                llvm::cl::init("-"));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<std::string> outputFilename(
-    "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
-    llvm::cl::init("-"));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> splitInputFile(
-    "split-input-file",
-    llvm::cl::desc("Split the input file into pieces and process each "
-                   "chunk independently"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> verifyDiagnostics(
-    "verify-diagnostics",
-    llvm::cl::desc("Check that emitted diagnostics match "
-                   "expected-* lines on the corresponding line"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> verifyPasses(
-    "verify-each",
-    llvm::cl::desc("Run the verifier after each transformation pass"),
-    llvm::cl::init(true));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> allowUnregisteredDialects(
-    "allow-unregistered-dialect",
-    llvm::cl::desc("Allow operation with no registered dialects"),
-    llvm::cl::init(false));
-
-// NOLINTNEXTLINE
-static llvm::cl::opt<bool> showDialects(
-    "show-dialects", llvm::cl::desc("Print the list of registered dialects"),
-    llvm::cl::init(false));
-
 int main(int argc, char **argv) {
-  mlir::registerAllDialects();
   mlir::registerAllPasses();
-
-  mlir::mhlo::registerAllDialects();
+  mlir::mhlo::registerAllMhloPasses();
+  mlir::lmhlo::registerAllLmhloPasses();
   mlir::kernel_gen::registerKernelGenPasses();
 
-  llvm::InitLLVM y(argc, argv);
+  mlir::DialectRegistry registry;
+  mlir::registerAllDialects(registry);
+  mlir::mhlo::registerAllMhloDialects(registry);
+  mlir::RegisterAllTensorFlowDialects(registry);
+  registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
 
-  // Register any pass manager command line options.
-  mlir::registerAsmPrinterCLOptions();
-  mlir::registerPassManagerCLOptions();
-  mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
-
-  // Parse pass names in main to ensure static initialization completed.
-  llvm::cl::ParseCommandLineOptions(argc, argv,
-                                    "MLIR modular optimizer driver\n");
-
-  if (showDialects) {
-    mlir::MLIRContext context;
-    llvm::outs() << "Registered Dialects:\n";
-    for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
-      llvm::outs() << dialect->getNamespace() << "\n";
-    }
-    return 0;
-  }
-
-  // Set up the input file.
-  std::string errorMessage;
-  auto file = mlir::openInputFile(inputFilename, &errorMessage);
-  if (!file) {
-    llvm::errs() << errorMessage << "\n";
-    return 1;
-  }
-
-  auto output = mlir::openOutputFile(outputFilename, &errorMessage);
-  if (!output) {
-    llvm::errs() << errorMessage << "\n";
-    exit(1);
-  }
-
-  if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
-                         splitInputFile, verifyDiagnostics, verifyPasses,
-                         allowUnregisteredDialects))) {
-    return 1;
-  }
-  // Keep the output file if the invocation of MlirOptMain was successful.
-  output->keep();
-  return 0;
+  return failed(
+      mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));
 }
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
index ef07c80..10a9d0c 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -67,6 +67,10 @@
 };
 
 struct BufferizePass : public BufferizePassBase<BufferizePass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<lmhlo::LmhloDialect>();
+  }
+
  public:
   void runOnOperation() override {
     OwningRewritePatternList patterns;
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc
index a0cfcae..a26198e 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc
@@ -36,6 +36,10 @@
 // * std.dealloc becomes tf_framework.dealloc_raw.
 class EmbedTFFrameworkPass
     : public EmbedTFFrameworkPassBase<EmbedTFFrameworkPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
+  }
+
  public:
   void runOnOperation() override {
     ModuleOp m = getOperation();
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
index 28d3647..d27f14d 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
@@ -38,6 +38,10 @@
 
 struct ShapeToDescriptorsPass
     : public ShapeToDescriptorsPassBase<ShapeToDescriptorsPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<scf::SCFDialect>();
+  }
+
  public:
   void runOnOperation() override {
     MLIRContext &ctx = getContext();
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc
index 42e8943..6e6d71f 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc
@@ -33,6 +33,10 @@
 
 class TestTFFrameworkToLLVMPass
     : public TestTFFrameworkLegalizeToLLVMPassBase<TestTFFrameworkToLLVMPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
  public:
   void runOnOperation() override {
     ModuleOp m = getOperation();
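
The getDependentDialects overrides added across these passes all serve the same purpose: with dialects now loaded per context, a pass that creates ops from a dialect not guaranteed to be loaded must declare that dependency so the pass manager can load it before the pass runs. A sketch of the pattern for a hypothetical lowering pass that produces SCF ops:

    #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
    #include "mlir/Pass/Pass.h"  // from @llvm-project

    struct MyLoweringPass
        : public mlir::PassWrapper<MyLoweringPass, mlir::FunctionPass> {
      void getDependentDialects(mlir::DialectRegistry &registry) const override {
        // Declare every dialect this pass may create ops from.
        registry.insert<mlir::scf::SCFDialect>();
      }
      void runOnFunction() override {
        // Rewrite patterns that produce scf.for / scf.if would run here.
      }
    };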
diff --git a/tensorflow/compiler/mlir/utils/array_container_utils.h b/tensorflow/compiler/mlir/utils/array_container_utils.h
new file mode 100644
index 0000000..c1a8981
--- /dev/null
+++ b/tensorflow/compiler/mlir/utils/array_container_utils.h
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_
+
+#include "absl/types/span.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+
+template <typename T>
+inline llvm::ArrayRef<T> SpanToArrayRef(absl::Span<const T> span) {
+  return llvm::ArrayRef<T>(span.data(), span.size());
+}
+
+template <typename T>
+inline llvm::ArrayRef<T> SpanToArrayRef(absl::Span<T> span) {
+  return llvm::ArrayRef<T>(span.data(), span.size());
+}
+
+template <typename T>
+inline llvm::MutableArrayRef<T> SpanToMutableArrayRef(absl::Span<T> span) {
+  return llvm::MutableArrayRef<T>(span.data(), span.size());
+}
+
+template <typename T>
+inline absl::Span<const T> ArrayRefToSpan(llvm::ArrayRef<T> ref) {
+  return absl::Span<const T>(ref.data(), ref.size());
+}
+
+template <typename T>
+inline absl::Span<T> MutableArrayRefToSpan(llvm::MutableArrayRef<T> ref) {
+  return absl::Span<T>(ref.data(), ref.size());
+}
+
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_
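
Both span types are thin (pointer, size) views, so these conversions alias the original storage rather than copying it. A small usage sketch, assuming only this header and the two span libraries (function name hypothetical):

    #include <vector>

    #include "absl/types/span.h"
    #include "llvm/ADT/ArrayRef.h"
    #include "tensorflow/compiler/mlir/utils/array_container_utils.h"

    void SpanBridgeExample() {
      std::vector<int> values = {1, 2, 3};

      // absl -> llvm: the ArrayRef aliases `values`; no copy is made.
      absl::Span<const int> span(values);
      llvm::ArrayRef<int> ref = mlir::SpanToArrayRef(span);

      // llvm -> absl: round-trip over the same storage.
      absl::Span<const int> back = mlir::ArrayRefToSpan(ref);
      (void)back;
    }

As with any view type, the result is only valid while the underlying container is alive.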
diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc
new file mode 100644
index 0000000..bc4e80f
--- /dev/null
+++ b/tensorflow/compiler/mlir/utils/name_utils.cc
@@ -0,0 +1,99 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
+
+#include <cctype>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "mlir/IR/Identifier.h"  // from @llvm-project
+
+namespace mlir {
+
+namespace {
+// Checks if a character is legal for a TensorFlow node name, with special
+// handling for the first character.
+bool IsLegalChar(char c, bool first_char) {
+  if (isalpha(c)) return true;
+  if (isdigit(c)) return true;
+  if (c == '.') return true;
+  if (c == '_') return true;
+
+  // First character of a node name can only be a letter, digit, dot or
+  // underscore.
+  if (first_char) return false;
+
+  if (c == '/') return true;
+  if (c == '-') return true;
+
+  return false;
+}
+}  // anonymous namespace
+
+void LegalizeNodeName(std::string& name) {
+  if (name.empty()) return;
+
+  if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.';
+
+  for (char& c : llvm::drop_begin(name, 1))
+    if (!IsLegalChar(c, /*first_char=*/false)) c = '.';
+}
+
+std::string GetNameFromLoc(Location loc) {
+  llvm::SmallVector<llvm::StringRef, 8> loc_names;
+  llvm::SmallVector<Location, 8> locs;
+  locs.push_back(loc);
+  bool names_is_nonempty = false;
+
+  while (!locs.empty()) {
+    Location curr_loc = locs.pop_back_val();
+
+    if (auto name_loc = curr_loc.dyn_cast<NameLoc>()) {
+      // Add the name from the NameLoc. This also accounts for names of ops
+      // inside functions, where the op's name precedes the '@'.
+      auto name = name_loc.getName().strref().split('@').first;
+      loc_names.push_back(name);
+      if (!name.empty()) names_is_nonempty = true;
+      continue;
+    } else if (auto call_loc = curr_loc.dyn_cast<CallSiteLoc>()) {
+      // Add name if CallSiteLoc's callee has a NameLoc (as should be the
+      // case if imported with DebugInfo).
+      if (auto name_loc = call_loc.getCallee().dyn_cast<NameLoc>()) {
+        auto name = name_loc.getName().strref().split('@').first;
+        loc_names.push_back(name);
+        if (!name.empty()) names_is_nonempty = true;
+        continue;
+      }
+    } else if (auto fused_loc = curr_loc.dyn_cast<FusedLoc>()) {
+      // Push the locations of a FusedLoc in reverse order, so they are
+      // visited in the order they appear in the FusedLoc.
+      auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
+      locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
+      continue;
+    }
+
+    // Location is not a supported kind, so an empty StringRef is added.
+    loc_names.push_back(llvm::StringRef());
+  }
+
+  if (names_is_nonempty)
+    return llvm::join(loc_names.begin(), loc_names.end(), ";");
+
+  return "";
+}
+
+}  // namespace mlir
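
A short sketch of how the two helpers compose when deriving a TensorFlow node name from an op's location; the wrapper function is hypothetical:

    #include <string>

    #include "mlir/IR/Operation.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/utils/name_utils.h"

    std::string NodeNameForOp(mlir::Operation* op) {
      // Derive a candidate from the location (several NameLocs may be joined
      // with ';'), then rewrite any characters TensorFlow would reject to '.'.
      std::string name = mlir::GetNameFromLoc(op->getLoc());
      mlir::LegalizeNodeName(name);
      return name;
    }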
diff --git a/tensorflow/compiler/mlir/utils/name_utils.h b/tensorflow/compiler/mlir/utils/name_utils.h
new file mode 100644
index 0000000..4b08a41
--- /dev/null
+++ b/tensorflow/compiler/mlir/utils/name_utils.h
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
+
+#include <string>
+
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Location.h"  // from @llvm-project
+
+namespace mlir {
+
+// Converts characters in `name` that are considered illegal in a TensorFlow
+// node name to '.'.
+void LegalizeNodeName(std::string& name);
+
+// Creates a TensorFlow node name from a location.
+std::string GetNameFromLoc(Location loc);
+
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
diff --git a/tensorflow/compiler/mlir/utils/string_container_utils.h b/tensorflow/compiler/mlir/utils/string_container_utils.h
new file mode 100644
index 0000000..fb2fa06
--- /dev/null
+++ b/tensorflow/compiler/mlir/utils/string_container_utils.h
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_
+
+#include "absl/strings/string_view.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+inline absl::string_view StringRefToView(llvm::StringRef ref) {
+  return absl::string_view(ref.data(), ref.size());
+}
+
+inline llvm::StringRef StringViewToRef(absl::string_view view) {
+  return llvm::StringRef(view.data(), view.size());
+}
+
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_
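
Like the array helpers above, these are zero-copy view conversions. A minimal sketch (function name hypothetical):

    #include "absl/strings/string_view.h"
    #include "llvm/ADT/StringRef.h"
    #include "tensorflow/compiler/mlir/utils/string_container_utils.h"

    void StringBridgeExample() {
      llvm::StringRef ref = "node_name";  // Backed by the string literal.
      absl::string_view view = mlir::StringRefToView(ref);
      llvm::StringRef round_trip = mlir::StringViewToRef(view);
      (void)round_trip;  // Both views point at the same bytes as `ref`.
    }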
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 4c14bcf..ec98d9d 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -133,7 +133,6 @@
         ":hlo_utils",
         ":mlir_hlo_to_hlo",
         "//tensorflow/compiler/mlir/hlo",
-        "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:statusor",
@@ -239,7 +238,6 @@
     deps = [
         ":type_to_shape",
         "//tensorflow/compiler/mlir/hlo",
-        "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
         "//tensorflow/compiler/mlir/tensorflow:convert_type",
         "//tensorflow/compiler/mlir/tensorflow:error_util",
         "//tensorflow/compiler/tf2xla:common",
@@ -334,6 +332,7 @@
         ":mlir_hlo_to_hlo",
         "//tensorflow/compiler/jit:xla_cpu_jit",
         "//tensorflow/compiler/jit:xla_gpu_jit",
+        "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:statusor",
@@ -342,6 +341,7 @@
         "//tensorflow/core:lib",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Translation",
     ],
     alwayslink = 1,
@@ -388,7 +388,6 @@
         ":xla_legalize_tf_with_tf2xla",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
-        "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
         "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
         "//tensorflow/compiler/mlir/hlo:legalize_control_flow",
         "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation",
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
index db981bb..e0cc890 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -19,6 +19,7 @@
 #include <unordered_map>
 
 #include "absl/types/optional.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
@@ -62,7 +63,10 @@
       : context_(module.getContext()),
         module_(module),
         builder_(builder),
-        function_map_(function_map) {}
+        function_map_(function_map) {
+    context_->loadDialect<mlir::StandardOpsDialect>();
+    context_->loadDialect<mlir::mhlo::MhloDialect>();
+  }
 
   // Imports the given computation as a new function, if it hasn't been already
   // imported.
diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
index dd045da..9db5861 100644
--- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc
@@ -30,6 +30,12 @@
 
 namespace xla {
 
+HloModuleImporter::HloModuleImporter(mlir::ModuleOp module)
+    : module_(module), builder_(module.getContext()) {
+  module.getContext()->loadDialect<mlir::StandardOpsDialect>();
+  module.getContext()->loadDialect<mlir::mhlo::MhloDialect>();
+}
+
 Status HloModuleImporter::Import(const xla::HloModule& module) {
   // TODO(hinsu): Only import the entry computation here once all HLO ops with
   // reference to other computation are updated to have a region instead of a
diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h
index 69ac1e2..4012994 100644
--- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h
@@ -38,8 +38,7 @@
 // dialect. HloModuleImporter does not take ownership.
 class HloModuleImporter {
  public:
-  explicit HloModuleImporter(mlir::ModuleOp module)
-      : module_(module), builder_(module.getContext()) {}
+  explicit HloModuleImporter(mlir::ModuleOp module);
 
   // Import the HloModule into the MLIR Module.
   Status Import(const xla::HloModule& module);
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir
index 550b2ba..876a1bf 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir
@@ -169,7 +169,7 @@
   // CHECK:      "mhlo.send"([[ARG0]], [[INIT_TOKEN]])
   // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64}
   // CHECK-SAME: is_host_transfer = true
-  // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key"}
+  // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key_dtoh_0"}
   // CHECK-SAME: (tensor<i32>, !mhlo.token) -> !mhlo.token
   "tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor<i32>) -> ()
   return
@@ -186,7 +186,7 @@
   // CHECK:      [[RECV_TUPLE:%.*]] = "mhlo.recv"([[INIT_TOKEN]])
   // CHECK-SAME: channel_id = {handle = 1 : i64, type = 3 : i64}
   // CHECK-SAME: is_host_transfer = true
-  // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key"}
+  // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key_htod_0"}
   // CHECK-SAME: (!mhlo.token) -> tuple<tensor<i32>, !mhlo.token>
 
 
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index de1e592..8c8d999 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -220,13 +220,6 @@
   return %0 : tensor<3x3xf32>
 }
 
-// CHECK-LABEL: fft
-func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
-  // CHECK: "mhlo.fft"(%arg0)
-  %0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
-  return %0 : tensor<3x5x8xcomplex<f32>>
-}
-
 // CHECK-LABEL: reverse_sequence
 func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> {
   // CHECK-NOT: tf.ReverseSequence
@@ -298,6 +291,14 @@
   return %1 : tensor<1000xi32>
 }
 
+// CHECK-LABEL: multinomial
+func @multinomial(%arg0: tensor<2x4xf32>, %seed: tensor<i32>, %seed2: tensor<i32>) -> tensor<2x10xi32> {
+  // CHECK-NOT: tf.Multinomial
+  %samples = "tf.Const"() { value = dense<10> : tensor<i32> } : () -> tensor<i32>
+  %1 = "tf.Multinomial"(%arg0, %samples) {seed = 0, seed2 = 0}: (tensor<2x4xf32>, tensor<i32>) -> tensor<2x10xi32>
+  return %1 : tensor<2x10xi32>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 2850f63..a05cb0b 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1500,6 +1500,35 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Elu op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @elu
+func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %arg0, %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"}
+  // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0)
+  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]])
+  // CHECK: return %[[RESULT]]
+  %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+  return %0: tensor<1xf32>
+}
+
+// CHECK-LABEL: func @elu_grad
+// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
+func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
+  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"}
+  // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>}
+  // CHECK-DAG: %[[MULGRAD:.*]] = "mhlo.multiply"(%[[GRADIENTS]], %[[ADD1]])
+  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]])
+  // CHECK: return %[[RESULT]]
+  %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
+  return %2 : tensor<4x8xf32>
+}
+
+//===----------------------------------------------------------------------===//
 // Relu op legalizations.
 //===----------------------------------------------------------------------===//
 
@@ -1726,6 +1755,20 @@
 // Fast Fourier Transform op legalization.
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: func @fft_1D
+func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
+  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "FFT"} : (tensor<8xcomplex<f32>>
+  %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
+  return %0 : tensor<8xcomplex<f32>>
+}
+
+// CHECK-LABEL: func @ifft_1D
+func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
+  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IFFT"} : (tensor<8xcomplex<f32>>
+  %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
+  return %0 : tensor<8xcomplex<f32>>
+}
+
 // CHECK-LABEL: func @rfft_1D
 func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
   %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
@@ -1734,6 +1777,48 @@
   return %0 : tensor<8xcomplex<f32>>
 }
 
+// CHECK-LABEL: func @rfft_1D_padded
+func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<8xcomplex<f32>> {
+  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %2) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32>
+  // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32>
+  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>>
+  return %0 : tensor<8xcomplex<f32>>
+}
+
+// CHECK-LABEL: func @rfft_1D_sliced
+func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x8xcomplex<f32>> {
+  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32>
+  // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<2x8xf32>
+  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x8xcomplex<f32>>
+  return %0 : tensor<2x8xcomplex<f32>>
+}
+
+// CHECK-LABEL: func @irfft_1D
+func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<5xf32> {
+  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex<f32>>) -> tensor<5xcomplex<f32>>
+  // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<5> : tensor<1xi64>, fft_type = "IRFFT"} : (tensor<5xcomplex<f32>>
+  %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex<f32>>, tensor<1xi32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: fft_1D_dynamic
+func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
+  // CHECK: "tf.FFT"
+  %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>>
+  return %0 : tensor<8xcomplex<f32>>
+}
+
+// CHECK-LABEL: rfft_1D_dynamic
+func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<8xcomplex<f32>> {
+  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+  // CHECK: "tf.RFFT"
+  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>>
+  return %0 : tensor<8xcomplex<f32>>
+}
+
 //===----------------------------------------------------------------------===//
 // Shape op legalization.
 //===----------------------------------------------------------------------===//
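
Taken together, the new FFT tests above pin down the legalization contract: a
static inner axis shorter than fft_length is zero-padded (rfft_1D_padded), a
longer one is sliced (rfft_1D_sliced, irfft_1D), and a dynamic inner axis keeps
the TF op (fft_1D_dynamic, rfft_1D_dynamic). A standalone sketch of that
decision, illustrative only (the real lowering emits mhlo.pad/mhlo.slice ops):

#include <cstdint>
#include <vector>

enum class Adjust { kNone, kSlice, kPad };

// Mirrors the slice-or-pad choice for a ranked input with a static last
// axis. For IRFFT the effective length is fft_length / 2 + 1, the expected
// width of the complex input.
Adjust ChooseAdjustment(const std::vector<int64_t>& shape, int64_t fft_length) {
  const int64_t last = shape.back();
  if (last > fft_length) return Adjust::kSlice;  // truncate to fft_length
  if (last < fft_length) return Adjust::kPad;    // zero-pad up to fft_length
  return Adjust::kNone;                          // already the right width
}

// Examples matching the tests: RFFT on tensor<7xf32> with fft_length = 8
// pads one zero; IRFFT on tensor<8xcomplex<f32>> with fft_length = 8 slices
// the input down to 8 / 2 + 1 = 5 elements.
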
diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc
index 1a3f0c1..de8d6fc 100644
--- a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc
+++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc
@@ -42,13 +42,13 @@
  protected:
   XlaBuilderTest()
       : name_(SetupTest()),
-        context_(),
         module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))),
         builder_(&module_->getBodyRegion()),
-        xla_builder_(name_, builder_, module_->getLoc()) {}
+        xla_builder_(name_, builder_, module_->getLoc()) {
+    context_.loadDialect<mlir::mhlo::MhloDialect>();
+  }
 
   string SetupTest() {
-    mlir::registerDialect<mlir::mhlo::MhloDialect>();
     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
   }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir
index 97c53cb..0c2aee5 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir
@@ -2,6 +2,6 @@
 
 // CHECK: Opaque elements attr not supported
 func @main() {
-  %0 = "tf.Const"() {value = opaque<"tf", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
+  %0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
   return
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 3462b3b..0332dbe 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -15,6 +15,7 @@
 
 // This file implements logic for lowering TensorFlow dialect to XLA dialect.
 
+#include <cctype>
 #include <cstddef>
 #include <cstdint>
 #include <iterator>
@@ -42,6 +43,7 @@
 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@@ -69,6 +71,11 @@
 constexpr char kShardingAttr[] = "mhlo.sharding";
 
 class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
+                    shape::ShapeDialect, StandardOpsDialect>();
+  }
+
  public:
   LegalizeTF() = default;
   LegalizeTF(const LegalizeTF &) {}
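
Several passes in this change gain a getDependentDialects override. With
global dialect registration being phased out in MLIR, a pass must declare
every dialect whose ops its rewrites may create so the context loads them
before the pass runs. A minimal sketch of the idiom (ExamplePass is
hypothetical; the dialect set follows LegalizeTF above):

#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace {
// Declares the dialects this pass may create ops from so the MLIRContext
// loads them before any rewriting begins.
struct ExamplePass : public mlir::PassWrapper<ExamplePass, mlir::FunctionPass> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::mhlo::MhloDialect>();
  }
  void runOnFunction() override {}
};
}  // namespace
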
@@ -731,6 +738,27 @@
 }
 
 //===----------------------------------------------------------------------===//
+// FFT op utilities.
+//===----------------------------------------------------------------------===//
+// Returns a 1D i64 elements attribute holding the size of the inner-most dim
+// of the given shaped type, or an empty attribute for rank-0 types.
+static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type,
+                                                 Builder *builder) {
+  if (type.getRank() == 0) {
+    return builder->getI64TensorAttr({});
+  }
+  return builder->getI64TensorAttr(type.getShape().back());
+}
+
+// Returns true if the inner-most dim is static.
+bool CheckInnerDimStatic(ShapedType type, Builder *builder) {
+  if (!type.hasRank()) {
+    return false;
+  }
+  return !type.isDynamicDim(type.getShape().size() - 1);
+}
+
+//===----------------------------------------------------------------------===//
 // MatMul op utilities.
 //===----------------------------------------------------------------------===//
 
@@ -1679,6 +1707,80 @@
   }
 };
 
+template <typename OpTy>
+class ConvertFFTOp : public OpRewritePattern<OpTy> {
+ public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    auto input_ty = op.input().getType().template cast<ShapedType>();
+    if (!input_ty.hasRank()) {
+      return failure();
+    }
+    auto input_shape = input_ty.getShape();
+    DenseIntElementsAttr fft_length_attr;
+    if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) {
+      return failure();
+    }
+    int64_t fft_length;
+    if (fft_length_attr.getNumElements() != 0) {
+      fft_length = fft_length_attr.getValue<IntegerAttr>(0).getInt();
+    } else {
+      return failure();
+    }
+
+    std::string fft_string = "RFFT";
+    if (typeid(OpTy) == typeid(TF::IRFFTOp)) {
+      fft_length = fft_length / 2 + 1;
+      fft_string = "IRFFT";
+    }
+    auto loc = op.getLoc();
+
+    // The inner-most dim cannot be dynamic.
+    if (input_ty.isDynamicDim(input_shape.size() - 1)) {
+      return failure();
+    }
+
+    auto expected_shape = llvm::to_vector<4>(input_shape.drop_back());
+    expected_shape.push_back(fft_length);
+
+    // Zero-pad or truncate the last axis to match fft_length.
+    Value reshaped = op.input();
+    SmallVector<int64_t, 4> begin_indices(input_shape.size(), 0);
+    SmallVector<int64_t, 4> strides(input_shape.size(), 1);
+
+    // If the last dim is larger than fft_length, slice the input.
+    if (input_shape.back() > fft_length) {
+      reshaped = rewriter.create<SliceOp>(
+          op.getLoc(),
+          RankedTensorType::get(expected_shape, input_ty.getElementType()),
+          op.input(), GetI64ElementsAttr(begin_indices, &rewriter),
+          GetI64ElementsAttr(expected_shape, &rewriter),
+          GetI64ElementsAttr(strides, &rewriter));
+
+      // If the last dim is smaller than fft_length, zero-pad the input.
+    } else if (input_ty.getShape().back() < fft_length) {
+      SmallVector<int64_t, 4> no_padding(input_shape.size(), 0);
+      SmallVector<int64_t, 4> padding(input_shape.size() - 1, 0);
+      padding.push_back(fft_length - input_shape.back());
+      Value zero =
+          GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter);
+      reshaped = rewriter.create<PadOp>(
+          loc, RankedTensorType::get(expected_shape, input_ty.getElementType()),
+          op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter),
+          GetI64ElementsAttr(padding, &rewriter),
+          GetI64ElementsAttr(no_padding, &rewriter));
+    }
+
+    rewriter.replaceOpWithNewOp<FftOp>(op, op.getType(), reshaped, fft_string,
+                                       rewriter.getI64TensorAttr(fft_length));
+    return success();
+  }
+};
+
+using ConvertRFFTOp = ConvertFFTOp<TF::RFFTOp>;
+using ConvertIRFFTOp = ConvertFFTOp<TF::IRFFTOp>;
+
 // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
 // BatchNormGradOp for training and a sequence of binary ops for inference.
 // TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
@@ -5862,16 +5964,17 @@
       ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
       ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
       ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
-      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
-      ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
-      ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
-      ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
-      ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp,
-      ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
-      ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
-      ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
-      ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp,
-      ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
+      ConvertRFFTOp, ConvertIRFFTOp, ConvertFusedBatchNormGradOp,
+      ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
+      ConvertFusedBatchNormV2Op, ConvertFusedBatchNormV3Op,
+      ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp,
+      ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp,
+      ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp,
+      ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
+      ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
+      ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp,
+      ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op,
+      ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
       ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
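
The FFT helpers added near the top of this file drive the new tf.FFT/tf.IFFT
patterns in legalize_tf_patterns.td: GetInnerDimFromValue reads the static
inner-most dimension (taken from the result type, where it becomes the
fft_length attribute), and CheckInnerDimStatic gates the rewrite on that axis
being static. A small sketch of the extraction on a plain shape vector
(illustrative; the real helper builds a DenseIntElementsAttr):

#include <cstdint>
#include <vector>

// Returns the inner-most dimension as a single-element vector, or an empty
// vector for rank-0 shapes, mirroring GetInnerDimFromValue.
std::vector<int64_t> InnerDim(const std::vector<int64_t>& shape) {
  if (shape.empty()) return {};
  return {shape.back()};
}

// For a result of type tensor<8xcomplex<f32>>, InnerDim({8}) yields {8},
// which becomes fft_length = dense<8> : tensor<1xi64> on the mhlo.fft op.
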
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
index 1d6ce36..6320ad2 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc
@@ -60,6 +60,10 @@
 // ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
 class LegalizeTFCommunication
     : public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<mhlo::MhloDialect>();
+  }
+
  public:
   void runOnOperation() override;
 };
@@ -215,11 +219,17 @@
 }
 
 // Assigns frontend attributes holding information about data type and
-// TensorFlow rendezvous channel name.
-void SetFrontendAttributes(Operation* op, StringRef key, Type type) {
+// TensorFlow rendezvous channel name. The rendezvous channel name is handled
+// specially: an individual name is generated for each data send and receive.
+void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
+                           Type type, bool device_to_host) {
   MLIRContext* context = op->getContext();
 
-  auto rendezvous_name = StringAttr::get(key, context);
+  std::string formatted_key =
+      device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
+                     : llvm::formatv("{0}_htod_{1}", key, index).str();
+
+  auto rendezvous_name = StringAttr::get(formatted_key, context);
   auto rendezvous_name_attr = NamedAttribute(
       Identifier::get(kXlaHostTransferRendezvousNameAttr, context),
       rendezvous_name);
@@ -239,24 +249,10 @@
   op->setAttr(kFrontendAttributesAttr, frontend_attributes);
 }
 
-// Assigns frontend attributes holding information about data type and
-// TensorFlow rendezvous channel name specific to `tf._XlaHostComputeMlir`.
-// TensorFlow rendezvous channel name is handled differently as individual names
-// are used per data send and receive.
-void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
-                           Type type, bool device_to_host) {
-  std::string formatted_key =
-      device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
-                     : llvm::formatv("{0}_htod_{1}", key, index).str();
-
-  return SetFrontendAttributes(op, formatted_key, type);
-}
-
-// Creates a `mhlo.send` op for sending value `operand`. If `index` is set,
-// `key` will be rewritten with a suffix and index. If `tpu_core` is set, op
-// sharding for the respective device will be set.
+// Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set,
+// op sharding for the respective device will be set.
 Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
-                   Value operand, StringRef key, const Optional<size_t>& index,
+                   Value operand, StringRef key, size_t index,
                    const Optional<int64_t>& tpu_core, Value token) {
   // type 2 == DEVICE_TO_HOST
   auto channel_handle = ChannelHandle::get(
@@ -266,23 +262,18 @@
       loc, token.getType(), operand, token, channel_handle,
       /*is_host_transfer=*/builder.getBoolAttr(true));
 
-  if (index) {
-    SetFrontendAttributes(send, *index, key, operand.getType(),
-                          /*device_to_host=*/true);
-  } else {
-    SetFrontendAttributes(send, key, operand.getType());
-  }
+  SetFrontendAttributes(send, index, key, operand.getType(),
+                        /*device_to_host=*/true);
 
   if (tpu_core) SetOpSharding(send, *tpu_core);
 
   return send.getResult();
 }
 
-// Creates a `mhlo.recv` op for receiving a value. If `index` is set, `key` will
-// be rewritten with a suffix and index. If `tpu_core` is set, op sharding for
-// the respective device will be set.
+// Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op
+// sharding for the respective device will be set.
 Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
-                   Value result, StringRef key, const Optional<size_t>& index,
+                   Value result, StringRef key, size_t index,
                    const Optional<int64_t>& tpu_core, Value token) {
   // type 3 == HOST_TO_DEVICE
   auto channel_handle = ChannelHandle::get(
@@ -294,12 +285,10 @@
   auto recv =
       builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
                              /*is_host_transfer=*/builder.getBoolAttr(true));
-  if (index) {
-    SetFrontendAttributes(recv, *index, key, result_type,
-                          /*device_to_host=*/false);
-  } else {
-    SetFrontendAttributes(recv, key, result.getType());
-  }
+
+  SetFrontendAttributes(recv, index, key, result_type,
+                        /*device_to_host=*/false);
+
   if (tpu_core) SetOpSharding(recv, *tpu_core);
 
   auto get_tuple_element =
@@ -369,7 +358,7 @@
   builder.setInsertionPoint(send_to_host);
   token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
                        send_to_host.input(), send_to_host.key(),
-                       /*index=*/llvm::None, /*tpu_core=*/llvm::None, token);
+                       /*index=*/0, /*tpu_core=*/llvm::None, token);
 
   send_to_host.erase();
   return token;
@@ -381,7 +370,7 @@
   builder.setInsertionPoint(recv_from_host);
   token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
                        recv_from_host.output(), recv_from_host.key(),
-                       /*index=*/llvm::None, /*tpu_core=*/llvm::None, token);
+                       /*index=*/0, /*tpu_core=*/llvm::None, token);
 
   recv_from_host.erase();
   return token;
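
Every host transfer now receives its own rendezvous name: the original key is
suffixed with the transfer direction and the operand index, "{key}_dtoh_{index}"
for device-to-host sends and "{key}_htod_{index}" for host-to-device receives,
and the single-key overload is folded away. A minimal sketch of the formatting
(plain std::string in place of llvm::formatv; the example key is hypothetical):

#include <cstdint>
#include <string>

// Builds the per-transfer rendezvous name assigned in SetFrontendAttributes.
std::string RendezvousName(const std::string& key, int32_t index,
                           bool device_to_host) {
  const char* direction = device_to_host ? "_dtoh_" : "_htod_";
  return key + direction + std::to_string(index);
}

// RendezvousName("host_compute_channel", 0, /*device_to_host=*/true)
// yields "host_compute_channel_dtoh_0".
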
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 1f5207e..3d4f468 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -285,9 +285,19 @@
 // FFT op patterns.
 //===----------------------------------------------------------------------===//
 
-def : Pat<(TF_RFFTOp $input, (TF_ConstOp I32ElementsAttr:$fft_length)),
-          (HLO_FftOp $input, HLO_FFT_TYPE_RFFT,
-           (CastElementsToI64Elements $fft_length))>;
+def GetInnerDimFromValue : NativeCodeCall<
+  "GetInnerDimFromValue($0.getType().cast<ShapedType>(), &$_builder)">;
+
+def CheckInnerDimStatic
+  : Constraint<CPred<"CheckInnerDimStatic($0.getType().cast<ShapedType>(), &$_builder)">>;
+
+def : Pat<(TF_FFTOp:$res $input),
+          (HLO_FftOp $input, HLO_FFT_TYPE_FFT, (GetInnerDimFromValue $res)),
+          [(CheckInnerDimStatic $input)]>;
+
+def : Pat<(TF_IFFTOp:$res $input),
+          (HLO_FftOp $input, HLO_FFT_TYPE_IFFT, (GetInnerDimFromValue $res)),
+          [(CheckInnerDimStatic $input)]>;
 
 //===----------------------------------------------------------------------===//
 // GatherV2 op patterns.
@@ -436,6 +446,35 @@
           [(HLO_Tensor $res)]>;
 
 //===----------------------------------------------------------------------===//
+// Elu op patterns.
+//===----------------------------------------------------------------------===//
+
+def : Pat<(TF_EluOp AnyRankedTensor:$features),
+          (HLO_SelectOp
+           (HLOClient_BroadcastCompareOp
+              $features,
+              (HLO_ConstOp:$zero (GetScalarOfType<0> $features)),
+              (BinBroadcastDimensions $zero, $features),
+              HLO_COMPARISON_DIRECTION_GT),
+           $features,
+           (HLO_Expm1Op $features))>;
+
+def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features),
+           (HLO_SelectOp
+            (HLOClient_BroadcastCompareOp
+              $features,
+              (HLO_ConstOp:$zero (GetScalarOfType<0> $features)),
+              (BinBroadcastDimensions $zero, $features),
+              HLO_COMPARISON_DIRECTION_GT),
+            $gradients,
+            (HLO_MulOp
+             $gradients,
+             (HLOClient_BroadcastAddOp
+               $features,
+               (HLO_ConstOp:$one (GetScalarOfType<1> $features)),
+               (BinBroadcastDimensions $one, $features))))>;
+
+//===----------------------------------------------------------------------===//
 // Relu op patterns.
 //===----------------------------------------------------------------------===//
 
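
The Elu patterns above lower tf.Elu to select(x > 0, x, expm1(x)) and
tf.EluGrad to select(x > 0, g, g * (x + 1)). A minimal scalar sketch of the
same math, illustrative only (it assumes, as in TF's EluGrad, that the second
operand carries the forward activations, so y + 1 = exp(x) for x <= 0):

#include <cmath>

// Elu as lowered above: identity for positive inputs, expm1(x) otherwise.
float Elu(float x) { return x > 0.0f ? x : std::expm1(x); }

// EluGrad as lowered above: pass the gradient through for positive
// activations; otherwise scale it by (y + 1), which equals exp(x) when y
// holds the Elu output for x <= 0.
float EluGrad(float gradient, float y) {
  return y > 0.0f ? gradient : gradient * (y + 1.0f);
}
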
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 3ab89e4..b06edcd 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -146,10 +146,8 @@
     TypeID::get<TF::HSVToRGBOp>(),
     TypeID::get<TF::IFFT2DOp>(),
     TypeID::get<TF::IFFT3DOp>(),
-    TypeID::get<TF::IFFTOp>(),
     TypeID::get<TF::IRFFT2DOp>(),
     TypeID::get<TF::IRFFT3DOp>(),
-    TypeID::get<TF::IRFFTOp>(),
     TypeID::get<TF::IgammaOp>(),
     TypeID::get<TF::IgammacOp>(),
     TypeID::get<TF::IgammaGradAOp>(),
@@ -177,10 +175,10 @@
     TypeID::get<TF::MatrixTriangularSolveOp>(),
     TypeID::get<TF::MirrorPadOp>(),
     TypeID::get<TF::MulOp>(),
+    TypeID::get<TF::MultinomialOp>(),
     TypeID::get<TF::NegOp>(),
     TypeID::get<TF::NonMaxSuppressionV4Op>(),
     TypeID::get<TF::NotEqualOp>(),
-    TypeID::get<TF::MultinomialOp>(),
     TypeID::get<TF::PadOp>(),
     TypeID::get<TF::PlaceholderWithDefaultOp>(),
     TypeID::get<TF::PowOp>(),
@@ -195,7 +193,6 @@
     TypeID::get<TF::RGBToHSVOp>(),
     TypeID::get<TF::RandomUniformIntOp>(),
     TypeID::get<TF::RealDivOp>(),
-    TypeID::get<TF::ReciprocalOp>(),
     TypeID::get<TF::ReciprocalGradOp>(),
     TypeID::get<TF::Relu6GradOp>(),
     TypeID::get<TF::ResizeBilinearOp>(),
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 2246242..ef362d9 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -25,6 +25,7 @@
 #include "mlir/IR/AffineMap.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
@@ -34,6 +35,8 @@
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassOptions.h"  // from @llvm-project
 #include "mlir/Translation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@@ -133,6 +136,11 @@
 // MLIR LHLO.
 class XlaHloToLhloPass
     : public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
+                    mlir::lmhlo::LmhloDialect>();
+  }
+
  public:
   XlaHloToLhloPass() = default;
   XlaHloToLhloPass(const XlaHloToLhloPass&) {}
@@ -438,7 +446,7 @@
   builder_.setInsertionPointToEnd(block);
 
   auto return_op = builder_.create<ReturnOp>(builder_.getUnknownLoc());
-  builder_ = mlir::OpBuilder(return_op);
+  builder_ = OpBuilder(return_op);
 
   return Status::OK();
 }
@@ -449,6 +457,9 @@
 
 Status HloToLhloModule(const BufferAssignment& assignment,
                        const HloModule& hlo_module, ModuleOp module) {
+  module.getContext()
+      ->loadDialect<StandardOpsDialect, mhlo::MhloDialect,
+                    lmhlo::LmhloDialect>();
   HloComputation* computation = hlo_module.entry_computation();
 
   LhloDialectEmitter emitter(assignment, *computation, module);
@@ -462,15 +473,14 @@
   return computation->AcceptOrdered(&emitter, ordering);
 }
 
-mlir::OwningModuleRef HloTextToLhloTranslateFunction(
-    llvm::StringRef input, mlir::MLIRContext* context) {
+OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
+                                               MLIRContext* context) {
   StatusOr<std::unique_ptr<HloModule>> maybe_module =
       xla::ParseAndReturnUnverifiedModule(
           absl::string_view(input.data(), input.size()));
   TF_CHECK_OK(maybe_module.status());
 
-  mlir::OwningModuleRef module =
-      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
+  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context));
 
   TF_CHECK_OK(
       ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host"));
diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
index 158671a..d5c5986 100644
--- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
+++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
@@ -17,8 +17,11 @@
 
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/MemoryBuffer.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/Translation.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
@@ -173,11 +176,17 @@
 
 }  // namespace xla
 
+static void RegisterInputDialects(mlir::DialectRegistry& registry) {
+  registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect>();
+}
+
 static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate(
-    "mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction);
+    "mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction,
+    RegisterInputDialects);
 
 static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate(
-    "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction);
+    "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction,
+    RegisterInputDialects);
 
 static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
     "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);
diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
index 4bd2dfd..41877d3 100644
--- a/tensorflow/compiler/tests/cholesky_op_test.py
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -28,7 +28,6 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
@@ -61,7 +60,7 @@
           dtypes.as_dtype(x.dtype), shape=x.shape)
       with self.test_scope():
         chol = linalg_ops.cholesky(placeholder)
-      verification = math_ops.matmul(chol, chol, adjoint_b=True)
+      verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True)
       self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol)
 
   def testBasic(self):
diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index 9d278cf..08aad66 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -29,7 +29,6 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
@@ -65,7 +64,8 @@
       with self.test_scope():
         x = linalg_ops.matrix_triangular_solve(
             placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
-      verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint)
+      verification = test_util.matmul_without_tf32(
+          placeholder_ca, x, adjoint_a=adjoint)
       self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
                                       placeholder_b, a, clean_a, b,
                                       verification, atol)
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index 5fcf254..b2d5db8 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -24,12 +24,17 @@
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_without_tensor_float_32(
+    "XLA QR op calls matmul, and matmul is also used for verification. With "
+    'TF32, a mysterious "Unable to launch cuBLAS gemm" error occasionally occurs')
+# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error
 class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
 
   def AdjustedNorm(self, x):
@@ -73,7 +78,7 @@
 
     with self.session() as sess:
       x_tf = array_ops.placeholder(dtype)
-      with self.test_scope():
+      with self.device_scope():
         q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
       q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
 
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 8c31629..de97c6f 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -237,8 +237,8 @@
         'test_session not supported on XLATestCase, please use session')
 
   @contextlib.contextmanager
-  def test_scope(self):
-    """Test scope that runs tests on `self.device`.
+  def device_scope(self):
+    """Scope that runs tests on `self.device`.
 
     Yields:
       A scope to apply to the operators under test.
@@ -246,6 +246,15 @@
     with ops.device('device:{}:0'.format(self.device)):
       yield
 
+  def test_scope(self):
+    """Deprecated alias of `device_scope`.
+
+    Avoid using this: because the name starts with `test`, test runners treat
+    it as a test method, which interferes with class decorators that operate
+    on each test method.
+    """
+    return self.device_scope()
+
 
 def Benchmark(tf_bench,
               builder_fn,
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index 0718bd8..44fb551 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -11,7 +11,6 @@
     "tf_custom_op_library_additional_deps",
     "tf_gen_op_libs",
     "tf_gen_op_wrapper_py",
-    "tf_gpu_kernel_library",
 )
 
 # buildifier: disable=same-origin-load
@@ -81,6 +80,7 @@
 
 cc_library(
     name = "common_utils",
+    srcs = ["common/utils.cc"],
     hdrs = ["common/utils.h"],
     copts = tf_copts(),
     deps = [
@@ -539,20 +539,6 @@
     ],
 )
 
-tf_gpu_kernel_library(
-    name = "plugin_cast",
-    srcs = ["plugin/plugin_cast.cu.cc"],
-    deps = [
-        ":trt_plugins",
-        "@com_google_absl//absl/strings",
-        "//tensorflow/core/platform:logging",
-        "//tensorflow/core:framework_lite",
-    ] + if_tensorrt([
-        "@local_config_cuda//cuda:cuda_headers",
-        "@local_config_tensorrt//:tensorrt",
-    ]),
-)
-
 tf_cuda_library(
     name = "trt_plugins",
     srcs = ["plugin/trt_plugin.cc"],
@@ -602,6 +588,7 @@
     link_in_framework = True,
     module_name = "_pywrap_py_utils",
     deps = [
+        ":common_utils",
         ":py_utils",
         "//tensorflow/core/platform:env",
         "//tensorflow/core/platform:logging",
diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.cc b/tensorflow/compiler/tf2tensorrt/common/utils.cc
new file mode 100644
index 0000000..6679ca0
--- /dev/null
+++ b/tensorflow/compiler/tf2tensorrt/common/utils.cc
@@ -0,0 +1,99 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
+
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+#include "absl/base/call_once.h"
+#include "absl/strings/str_join.h"
+#include "third_party/tensorrt/NvInferPlugin.h"
+#endif
+
+namespace tensorflow {
+namespace tensorrt {
+
+std::tuple<int, int, int> GetLinkedTensorRTVersion() {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+  return std::tuple<int, int, int>{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
+                                   NV_TENSORRT_PATCH};
+#else
+  return std::tuple<int, int, int>{0, 0, 0};
+#endif
+}
+
+std::tuple<int, int, int> GetLoadedTensorRTVersion() {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+  int ver = getInferLibVersion();
+  int major = ver / 1000;
+  ver = ver - major * 1000;
+  int minor = ver / 100;
+  int patch = ver - minor * 100;
+  return std::tuple<int, int, int>{major, minor, patch};
+#else
+  return std::tuple<int, int, int>{0, 0, 0};
+#endif
+}
+
+}  // namespace tensorrt
+}  // namespace tensorflow
+
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+namespace tensorflow {
+namespace tensorrt {
+namespace {
+
+void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
+  LOG(INFO) << "Linked TensorRT version: "
+            << absl::StrJoin(GetLinkedTensorRTVersion(), ".");
+  LOG(INFO) << "Loaded TensorRT version: "
+            << absl::StrJoin(GetLoadedTensorRTVersion(), ".");
+
+  bool plugin_initialized = initLibNvInferPlugins(trt_logger, "");
+  if (!plugin_initialized) {
+    LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
+                  "fail later.";
+  }
+
+  int num_trt_plugins = 0;
+  nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
+      getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
+  if (!trt_plugin_creator_list) {
+    LOG_WARNING_WITH_PREFIX << "Cannot find any TensorRT plugins in registry.";
+  } else {
+    VLOG(1) << "Found the following " << num_trt_plugins
+            << " TensorRT plugins in registry:";
+    for (int i = 0; i < num_trt_plugins; ++i) {
+      if (!trt_plugin_creator_list[i]) {
+        LOG_WARNING_WITH_PREFIX
+            << "TensorRT plugin at index " << i
+            << " is not accessible (null pointer returned by "
+               "getPluginCreatorList for this plugin)";
+      } else {
+        VLOG(1) << "  " << trt_plugin_creator_list[i]->getPluginName();
+      }
+    }
+  }
+}
+
+}  // namespace
+
+void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
+  static absl::once_flag once;
+  absl::call_once(once, InitializeTrtPlugins, trt_logger);
+}
+
+}  // namespace tensorrt
+}  // namespace tensorflow
+#endif
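
getInferLibVersion() packs the loaded TensorRT version as
major * 1000 + minor * 100 + patch, which GetLoadedTensorRTVersion above
inverts; call sites such as convert_nodes.cc then join the tuple with
absl::StrJoin, which accepts std::tuple directly. A worked sketch of the
round trip (TensorRT 7.1.3 reports 7103):

#include <tuple>

#include "absl/strings/str_join.h"

// Inverts the major * 1000 + minor * 100 + patch packing returned by
// getInferLibVersion(), as in GetLoadedTensorRTVersion above.
std::tuple<int, int, int> DecodeTrtVersion(int ver) {
  int major = ver / 1000;
  ver -= major * 1000;
  int minor = ver / 100;
  int patch = ver - minor * 100;
  return std::tuple<int, int, int>{major, minor, patch};
}

// DecodeTrtVersion(7103) yields {7, 1, 3}, and
// absl::StrJoin(DecodeTrtVersion(7103), ".") renders "7.1.3".
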
diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.h b/tensorflow/compiler/tf2tensorrt/common/utils.h
index b428733..b76b75d 100644
--- a/tensorflow/compiler/tf2tensorrt/common/utils.h
+++ b/tensorflow/compiler/tf2tensorrt/common/utils.h
@@ -16,15 +16,33 @@
 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_
 #define TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_
 
+#include <tuple>
+
+namespace tensorflow {
+namespace tensorrt {
+// Returns the compile-time TensorRT library version information
+// {Maj, Min, Patch}.
+std::tuple<int, int, int> GetLinkedTensorRTVersion();
+
+// Returns the runtime TensorRT library version information
+// {Maj, Min, Patch}.
+std::tuple<int, int, int> GetLoadedTensorRTVersion();
+}  // namespace tensorrt
+}  // namespace tensorflow
+
 #if GOOGLE_CUDA && GOOGLE_TENSORRT
 
 #include "tensorflow/core/platform/logging.h"
+#include "third_party/tensorrt/NvInfer.h"
 
 namespace tensorflow {
 namespace tensorrt {
 
 #define LOG_WARNING_WITH_PREFIX LOG(WARNING) << "TF-TRT Warning: "
 
+// Initializes the TensorRT plugin registry if this hasn't been done yet.
+void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger);
+
 }  // namespace tensorrt
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index f80c0f4..c0c3f25 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -1197,42 +1197,6 @@
   return status;
 }
 
-static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) {
-  static mutex plugin_mutex(LINKER_INITIALIZED);
-  static bool plugin_initialized = false;
-  mutex_lock lock(plugin_mutex);
-  if (plugin_initialized) return;
-
-  LOG(INFO) << "Linked TensorRT version: " << GetLinkedTensorRTVersion();
-  LOG(INFO) << "Loaded TensorRT version: " << GetLoadedTensorRTVersion();
-
-  plugin_initialized = initLibNvInferPlugins(trt_logger, "");
-  if (!plugin_initialized) {
-    LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may "
-                  "fail later.";
-  }
-
-  int num_trt_plugins = 0;
-  nvinfer1::IPluginCreator* const* trt_plugin_creator_list =
-      getPluginRegistry()->getPluginCreatorList(&num_trt_plugins);
-  if (!trt_plugin_creator_list) {
-    LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry.";
-  } else {
-    VLOG(1) << "Found the following " << num_trt_plugins
-            << " TensorRT plugins in registry:";
-    for (int i = 0; i < num_trt_plugins; ++i) {
-      if (!trt_plugin_creator_list[i]) {
-        LOG_WARNING_WITH_PREFIX
-            << "TensorRT plugin at index " << i
-            << " is not accessible (null pointer returned by "
-               "getPluginCreatorList for this plugin)";
-      } else {
-        VLOG(1) << "  " << trt_plugin_creator_list[i]->getPluginName();
-      }
-    }
-  }
-}
-
 // static
 StatusOr<std::unique_ptr<Converter>> Converter::Create(
     TrtPrecisionMode precision_mode, bool use_calibration,
@@ -1249,7 +1213,7 @@
     : precision_mode_(precision_mode),
       use_calibration_(use_calibration),
       use_implicit_batch_(use_implicit_batch) {
-  InitializeTrtPlugins(trt_logger);
+  MaybeInitializeTrtPlugins(trt_logger);
   this->RegisterOpConverters();
 }
 
@@ -1434,7 +1398,8 @@
   TF_RETURN_IF_ERROR(
       TrtPrecisionModeToName(precision_mode_, &precision_mode_str));
   string trt_network_name = StrCat(
-      "TF:", TF_VERSION_STRING, ", ", "TRT:", GetLoadedTensorRTVersion(), "-",
+      "TF:", TF_VERSION_STRING, ", ",
+      "TRT:", absl::StrJoin(GetLoadedTensorRTVersion(), "."), "-",
       "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_,
       ", ", "Max-Batch-Size:", max_batch_size, ", ",
       "Max-Workspace-Size:", max_workspace_size_bytes);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 72348c3..b127337 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -1709,12 +1709,12 @@
           std::tuple<TrtTestMode, DataType, TrtPrecisionMode>> {
  public:
   ParameterizedOpConverterTestBase()
-      : trt_mode(std::get<0>(GetParam())),
-        tf_type(std::get<1>(GetParam())),
-        converter_precision(std::get<2>(GetParam())) {}
+      : trt_mode_(std::get<0>(GetParam())),
+        tf_type_(std::get<1>(GetParam())),
+        converter_precision_(std::get<2>(GetParam())) {}
 
   void Reset() {
-    OpConverterTest::Reset(converter_precision, trt_mode);
+    OpConverterTest::Reset(converter_precision_, trt_mode_);
     input_data_.clear();
   }
 
@@ -1750,7 +1750,7 @@
     if (!partial_input_shape_dims.empty()) {
       partial_shape = partial_input_shape_dims;
     } else {
-      if (trt_mode == TrtTestMode::kDynamicShape) {
+      if (trt_mode_ == TrtTestMode::kDynamicShape) {
         // In dynamic shape mode we make all dims unknown.
         partial_shape = std::vector<int32>(dims.size(), -1);
       } else {
@@ -1776,7 +1776,7 @@
   void AddTestTensor(const string& name, const std::vector<int32>& dims,
                      const std::vector<T>& values = {},
                      const std::vector<int32>& partial_input_shape_dims = {}) {
-    AddTestTensor<T>(name, dims, tf_type, values, partial_input_shape_dims);
+    AddTestTensor<T>(name, dims, tf_type_, values, partial_input_shape_dims);
   }
 
   // Builds and runs the converted network. Checks output tensor shape. Tests
@@ -1796,7 +1796,7 @@
           TensorShapeUtils::MakeShape(expected_output_dims[i], &shape));
       string out_name = (n_output == 1) ? name : StrCat(name, ":", i);
       DataType out_tf_type =
-          out_tf_types.size() > i ? out_tf_types[i] : tf_type;
+          out_tf_types.size() > i ? out_tf_types[i] : tf_type_;
       InputOutputData data{
           out_name, ConstructTensor(shape.num_elements(), 0, out_tf_type)};
       output_data.push_back(data);
@@ -1840,9 +1840,9 @@
   }
 
  protected:
-  const TrtTestMode trt_mode;
-  const DataType tf_type;
-  const TrtPrecisionMode converter_precision;
+  const TrtTestMode trt_mode_;
+  const DataType tf_type_;
+  const TrtPrecisionMode converter_precision_;
   DataVec input_data_;
 };
 
@@ -2075,7 +2075,7 @@
                                      37.342354, 41.013527, 30.9738,   34.469433,
                                      45.018955, 48.59309,  59.369415, 63.04059};
   for (auto get_node_def : get_node_def_vec) {
-    NodeDef tmp_node_def = get_node_def(tf_type, "NCHW", true, 0);
+    NodeDef tmp_node_def = get_node_def(tf_type_, "NCHW", true, 0);
     std::string op_name = tmp_node_def.op();
     std::vector<TestParam> test_param{
         {"NHWC", 0, false, 0,
@@ -2097,7 +2097,7 @@
          errors::Unimplemented(StrCat("The input \"variance\" for ", op_name,
                                       " must be a constant, at my_batchnorm"))},
         {"NCHW", 0, false, 0.01}};  // The last one is the only test that runs.
-    if (trt_mode == TrtTestMode::kDynamicShape) {
+    if (trt_mode_ == TrtTestMode::kDynamicShape) {
       test_param.push_back(
           {"NCHW", 0, false, 0.01,
            errors::InvalidArgument(
@@ -2107,7 +2107,7 @@
     for (auto p : test_param) {
       Reset();
       NodeDef node_def =
-          get_node_def(tf_type, p.data_format, p.is_training, p.epsilon);
+          get_node_def(tf_type_, p.data_format, p.is_training, p.epsilon);
       for (int i = 0; i < node_input.size(); i++) {
         if (i == 0 || i == p.tensor_input_idx) {
          // The first input (x) is always added as a tensor, and it has shape
@@ -2126,7 +2126,7 @@
           // the first arg is a tensor. TODO(tfeher) Check if one can relax this
           // restriction.
           Status expected_status =
-              (i != 0 && trt_mode == TrtTestMode::kImplicitBatch)
+              (i != 0 && trt_mode_ == TrtTestMode::kImplicitBatch)
                   ? errors::InvalidArgument(
                         StrCat("Batch size doesn't match for tensor ",
                                node_input[i].name,
@@ -2134,19 +2134,19 @@
                                "converter batch size: 3 vs 2"))
                   : Status::OK();
           std::vector<int> partial_input_shape;
-          if (i == 0 && trt_mode == TrtTestMode::kDynamicShape &&
+          if (i == 0 && trt_mode_ == TrtTestMode::kDynamicShape &&
               !p.keep_channel_unknown) {
             // keep channel dim static (known)
             partial_input_shape.resize(4, -1);
             partial_input_shape[1] = node_input[i].dims[1];
           }
-          AddTestTensor(node_input[i].name, node_input[i].dims, tf_type,
+          AddTestTensor(node_input[i].name, node_input[i].dims, tf_type_,
                         node_input[i].val, partial_input_shape,
                         expected_status);
 
         } else {
           AddTestWeights(node_input[i].name, node_input[i].dims,
-                         node_input[i].val, tf_type);
+                         node_input[i].val, tf_type_);
         }
       }
       TestOpConverter("my_batchnorm", node_def, node_input[0].dims,
@@ -2154,12 +2154,12 @@
                       ArrayFloatNear(expected_output));
     }
   }
-}  // namespace convert
+}
 
 TEST_P(OpConverterTest1, ConvertTranspose) {
   // Get the NodeDef for Transpose.
   Scope s = Scope::NewRootScope();
-  auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
+  auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
   auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights);
   const NodeDef& node_def = transpose.operation.node()->def();
@@ -2187,13 +2187,13 @@
           {},
           {3, 2, 1, 1},
           {3, 2, 1, 0},
-          (trt_mode == TrtTestMode::kImplicitBatch)
+          (trt_mode_ == TrtTestMode::kImplicitBatch)
               ? Status(error::UNIMPLEMENTED,
                        "Transpose at batch dimension is not supported")
               : Status::OK()},
       TestParamBase{{1, 1, 2, 3}, {}, {1, 3, 1, 2}, {0, 3, 1, 2}},
   };
-  if (trt_mode == TrtTestMode::kDynamicShape) {
+  if (trt_mode_ == TrtTestMode::kDynamicShape) {
     // Dynamic shape tests where some shapes are known
     test_params.push_back(TestParamBase{
         {1, 1, 2, 3}, {-1, 1, 2, -1}, {1, 3, 1, 2}, {0, 3, 1, 2}});
@@ -2317,19 +2317,22 @@
 TEST_P(OpConverterTest1, ConvertShape) {
   // Get the NodeDef for Shape op.
   Scope s = Scope::NewRootScope();
-  auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
+  auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
   auto shape = ops::Shape(s.WithOpName("my_shape"), input);
   const NodeDef& node_def = shape.operation.node()->def();
 
   Status conversion_status =
-      (trt_mode == TrtTestMode::kImplicitBatch)
+      (trt_mode_ == TrtTestMode::kImplicitBatch)
           ? errors::Unimplemented(
                 "Shape is only supported for explicit batch mode.")
           : Status::OK();
   std::vector<TestParamBase> test_params = {
-      TestParamBase{{1, 2, 3}, {}, {3}, {}, conversion_status},
-      // Add input as weight (we use non empty param ({1}) to trigger this).
-      TestParamBase{{1, 2, 3}, {}, {3}, {1}, conversion_status},
+// TODO(b/166274212): Enable the test parameter for TensorRT 7.1.3.
+#if !IS_TRT_VERSION_GE(7, 1, 3, 0)
+    TestParamBase{{1, 2, 3}, {}, {3}, {}, conversion_status},
+#endif
+    // Add input as weight (we use a non-empty param ({1}) to trigger this).
+    TestParamBase{{1, 2, 3}, {}, {3}, {1}, conversion_status},
   };
 
   auto input_is_weight = [](const TestParamBase p) { return !p.param.empty(); };
@@ -2343,7 +2346,7 @@
     // we use for the unit test have no actual input tensor when it is converted
     // to a TensorRT network.
     int n_elements = 0;
-    if (input_is_weight(p) || trt_mode != TrtTestMode::kExplicitBatch) {
+    if (input_is_weight(p) || trt_mode_ != TrtTestMode::kExplicitBatch) {
       // Calculate the number of elements for adding input data.
       n_elements = std::accumulate(p.input_dims.begin(), p.input_dims.end(), 1,
                                    std::multiplies<int>());
@@ -2352,7 +2355,7 @@
     if (!input_is_weight(p)) {
       AddTestTensor("input", p.input_dims, input_val);
     } else {
-      AddTestWeights("input", p.input_dims, input_val, tf_type);
+      AddTestWeights("input", p.input_dims, input_val, tf_type_);
     }
     TestOpConverter("my_shape", node_def, p.expected_output_dims, p.status,
                     p.runtime_status, ElementsAreArray(p.input_dims),
@@ -2617,7 +2620,7 @@
   for (const string& data_format : {"NHWC", "NCHW"}) {
     for (const int trt_input_rank : {1, 2, 3, 4}) {
       Reset();
-      NodeDef node_def = get_biasadd_nodedef(data_format, tf_type);
+      NodeDef node_def = get_biasadd_nodedef(data_format, tf_type_);
 
       // Add input, dims_array will be like {2, 1, ..., 1, 3}
       std::vector<int32> dims_array(trt_input_rank + 1, 1);
@@ -2639,7 +2642,7 @@
       for (int i = 0; i < channel_size; ++i) {
         bias[i] = i + 1;  // bias will be {1, 2, 3, ...}
       }
-      AddTestWeights("weights", {channel_size}, bias, tf_type);
+      AddTestWeights("weights", {channel_size}, bias, tf_type_);
 
       // Build and run the engine.
       std::vector<float> output_data;
@@ -2675,7 +2678,7 @@
 TEST_P(OpConverterTest2, ConvertBinary) {
   {
     AttrValue dtype;
-    dtype.set_type(tf_type);
+    dtype.set_type(tf_type_);
     // Both inputs are weights.
     Reset();
     NodeDef node_def =
@@ -2720,19 +2723,19 @@
         if (!op_test_info.count(op_name)) {
           FAIL() << "Binary op test map does not contain op " << op_name;
         }
-        NodeDef node_def = op_test_info[op_name].first(tf_type);
+        NodeDef node_def = op_test_info[op_name].first(tf_type_);
         std::vector<std::string> input_names;
         std::vector<std::vector<int>> input_dims;
         std::vector<std::vector<float>> input_values;
         if (operand_1_is_tensor) {
           AddTestTensor("input1", {2, 1, 2}, {3, 6, 3, 6});
         } else {
-          AddTestWeights("input1", {1, 2}, std::vector<float>{3, 6}, tf_type);
+          AddTestWeights("input1", {1, 2}, std::vector<float>{3, 6}, tf_type_);
         }
         if (operand_2_is_tensor) {
           AddTestTensor("input2", {2, 2, 1}, {2, 3, 2, 3});
         } else {
-          AddTestWeights("input2", {2, 1}, std::vector<float>{2, 3}, tf_type);
+          AddTestWeights("input2", {2, 1}, std::vector<float>{2, 3}, tf_type_);
         }
         TestOpConverter("my_binary", node_def, {2, 2, 2}, Status::OK(),
                         Status::OK(),
@@ -2939,10 +2942,10 @@
     // Input is weights, should fail.
     Reset();
     Scope s = Scope::NewRootScope();
-    auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
+    auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
     auto square = ops::Square(s.WithOpName("my_square"), input);
     NodeDef node_def = square.operation.node()->def();
-    AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type);
+    AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type_);
     RunValidationAndConversion(
         node_def, error::UNIMPLEMENTED,
         "The input \"x\" for Square must be a tensor, at my_square");
@@ -2951,7 +2954,7 @@
   Reset();
 
   Scope s = Scope::NewRootScope();
-  auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
+  auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
   auto square = ops::Square(s.WithOpName("my_square"), input);
   NodeDef node_def = square.operation.node()->def();
 
@@ -2964,7 +2967,7 @@
     inputs[i] = value;
     expected_outputs[i] = value * value;
   }
-  AddTestTensor("input", {1, 1, 20}, tf_type, inputs);
+  AddTestTensor("input", {1, 1, 20}, tf_type_, inputs);
 
   TestOpConverter("my_square", node_def, {1, 1, 20}, Status::OK(), Status::OK(),
                   ArrayFloatNear(expected_outputs, 0));
@@ -3091,7 +3094,7 @@
   {
     // Input is weights, should fail.
     Reset();
-    const NodeDef& node_def = CreateUnaryOp<ops::Relu>(tf_type);
+    const NodeDef& node_def = CreateUnaryOp<ops::Relu>(tf_type_);
     AddTestWeights<int32>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
     RunValidationAndConversion(
         node_def, error::UNIMPLEMENTED,
@@ -3148,7 +3151,7 @@
       FAIL() << "Activation op test map does not contain op " << op_name;
     }
     Reset();
-    NodeDef node_def = op_map[op_name].first(tf_type);
+    NodeDef node_def = op_map[op_name].first(tf_type_);
     const std::vector<float> input = {-100, -2, -1, 0, 1, 88};
     AddTestTensor("input", p.input_dims, input);
 
@@ -3176,7 +3179,7 @@
 TEST_P(OpConverterTest1, ConvertExpandDims) {
   // Get the NodeDef for ExpandDims.
   Scope s = Scope::NewRootScope();
-  auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
+  auto input = ops::Placeholder(s.WithOpName("input"), tf_type_);
   auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
   auto expanddims =
       ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights);
@@ -3204,7 +3207,7 @@
                     {},
                     {1, 1, 1, 2, 3},
                     {0},
-                    trt_mode == TrtTestMode::kImplicitBatch
+                    trt_mode_ == TrtTestMode::kImplicitBatch
                         ? Status(error::UNIMPLEMENTED,
                                  "TensorRT does not allow manipulation of the "
                                  "batch dimension, at my_expanddims")
@@ -3213,7 +3216,7 @@
                     {},
                     {1, 1, 1, 2, 3},
                     {-5},
-                    trt_mode == TrtTestMode::kImplicitBatch
+                    trt_mode_ == TrtTestMode::kImplicitBatch
                         ? Status(error::UNIMPLEMENTED,
                                  "TensorRT does not allow manipulation of the "
                                  "batch dimension, at my_expanddims")
@@ -3251,7 +3254,7 @@
 }
 
 TEST_P(OpConverterTest1, ConvertSqueeze) {
-  const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch);
+  const bool use_implicit_batch = (trt_mode_ == TrtTestMode::kImplicitBatch);
   // Get the NodeDef for Squeeze.
   auto get_squeeze_nodedef = [](std::vector<int> axes,
                                 DataType tf_type) -> NodeDef {
@@ -3274,7 +3277,7 @@
           {},            // input partial dims
           {2, 3},        // expected output dims
           {},            // axis
-          trt_mode == TrtTestMode::kExplicitBatch
+          trt_mode_ == TrtTestMode::kExplicitBatch
               ? Status::OK()
               : Status{error::UNIMPLEMENTED,
                        "Squeeze is not implemented for empty squeeze_dims, at "
@@ -3333,7 +3336,7 @@
              "Dimension 2 with size 2 cannot be squeezed because it must be "
              "size 1, at my_squeeze"}};
 
-  if (trt_mode == TrtTestMode::kDynamicShape) {
+  if (trt_mode_ == TrtTestMode::kDynamicShape) {
     // In this test we try to squeeze axis=2 which has size > 1. In dynamic
     // shape mode the converter sees only -1, so it cannot catch this error.
     squeeze_non_singleton.status = Status::OK();  // conversion status
@@ -3348,7 +3351,7 @@
   for (TestParamBase p : test_params) {
     SCOPED_TRACE(p);
     Reset();
-    NodeDef node_def = get_squeeze_nodedef(p.param, tf_type);
+    NodeDef node_def = get_squeeze_nodedef(p.param, tf_type_);
     AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6},
                   p.partial_input_dims);
     TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status,
@@ -4103,14 +4106,14 @@
 
 TEST_P(OpConverterTest1, ConvertConv2D) {
   // Get nodedef for Conv2D layer.
-  DataType tf_type_loc = tf_type;
+  DataType tf_type = tf_type_;
   auto get_conv2d_nodedef =
-      [tf_type_loc](std::vector<int> strides = {1, 1, 1, 1},
-                    string padding = "SAME", string data_format = "NCHW",
-                    std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
+      [tf_type](std::vector<int> strides = {1, 1, 1, 1},
+                string padding = "SAME", string data_format = "NCHW",
+                std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
     Scope s = Scope::NewRootScope();
-    auto input = ops::Placeholder(s.WithOpName("input"), tf_type_loc);
-    auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type_loc);
+    auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
+    auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type);
     ops::Conv2D::Attrs attrs =
         ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
     auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides,
@@ -4203,12 +4206,12 @@
         node_def, error::UNIMPLEMENTED,
         "Stride must be 1 for batch and channel dimensions, at my_conv2d");
   }
-  if (trt_mode == TrtTestMode::kDynamicShape) {
+  if (trt_mode_ == TrtTestMode::kDynamicShape) {
     Reset();
     NodeDef node_def = get_conv2d_nodedef();
     // Channel dim unknown, should fail.
     AddTestTensorWithTFDims("input", {-1, -1, -1, -1},
-                            TfDataTypeToTrt(tf_type));
+                            TfDataTypeToTrt(tf_type_));
     AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1});
     RunValidationAndConversion(
         node_def, error::INVALID_ARGUMENT,
@@ -4230,8 +4233,6 @@
 
   // Ok.
   std::vector<TestParams> ok_params = {
-// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x.
-#if !IS_TRT_VERSION_GE(7, 1, 3, 0)
     // Basic
     TestParams{/*input_dims=*/{1, 1, 2, 3},
                /*input=*/{0, 1, 2, 3, 3, 4},
@@ -4243,9 +4244,6 @@
                /*dilations=*/{1, 1, 1, 1},
                /*expected_output_dims=*/{1, 1, 2, 2},
                /*expected_output=*/{1, 1, 0, 1}},
-#endif
-// TODO(b/162448349): Enable the test parameters for TRT 7.1.3.x.
-#if !IS_TRT_VERSION_GE(7, 1, 3, 0)
     // SAME padding (Asymmetric)
     TestParams{/*input_dims=*/{1, 1, 2, 3},
                /*input=*/{0, 1, 2, 3, 3, 4},
@@ -4268,9 +4266,6 @@
                /*dilations=*/{1, 1, 1, 1},
                /*expected_output_dims=*/{1, 1, 2, 3},
                /*expected_output=*/{1, 2, -1, 3, 1, -3}},
-#endif
-// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x.
-#if !IS_TRT_VERSION_GE(7, 1, 3, 0)
     // NHWC
     TestParams{/*input_dims=*/{1, 2, 3, 1},
                /*input=*/{0, 1, 2, 3, 3, 4},
@@ -4304,7 +4299,6 @@
                /*dilations=*/{1, 1, 1, 1},
                /*expected_output_dims=*/{1, 1, 2, 2},
                /*expected_output=*/{1, 0, 1, 3}},
-#endif
   };
 
   for (int i = 0; i < ok_params.size(); i++) {
@@ -4313,15 +4307,15 @@
         get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding,
                            ok_params[i].data_format, ok_params[i].dilations);
     std::vector<int> partial_input_shape;
-    if (trt_mode == TrtTestMode::kDynamicShape) {
+    if (trt_mode_ == TrtTestMode::kDynamicShape) {
       // The channel dim cannot have unknown size, fix that.
       partial_input_shape.resize(ok_params[i].input_dims.size(), -1);
       int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3;
       partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id];
     }
 
-    AddTestTensor("input", ok_params[i].input_dims, tf_type, ok_params[i].input,
-                  partial_input_shape);
+    AddTestTensor("input", ok_params[i].input_dims, tf_type_,
+                  ok_params[i].input, partial_input_shape);
     AddTestWeights<float>("weights", ok_params[i].filter_dims,
                           ok_params[i].filter);
 
@@ -4848,7 +4842,7 @@
   for (int nDim : test_nDims) {
     // Input is weights, should fail.
     Reset();
-    NodeDef node_def = get_pool_nodedef(tf_type, nDim);
+    NodeDef node_def = get_pool_nodedef(tf_type_, nDim);
 
     AddTestWeights<float>("input", {1, 1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
@@ -4957,7 +4951,7 @@
       for (bool is_max_pooling : {true, false}) {
         Reset();
         NodeDef node_def =
-            get_pool_nodedef(tf_type, nDim, ksize, strides, p.padding,
+            get_pool_nodedef(tf_type_, nDim, ksize, strides, p.padding,
                              data_format, is_max_pooling);
         AddTestTensor("input", input_dims, input);
         TestOpConverter("my_pool", node_def, expected_output_dims, Status::OK(),
@@ -5019,7 +5013,7 @@
 TEST_P(OpConverterTest3, ConvertGather) {
   // Get the NodeDef for GatherV2.
   Scope s = Scope::NewRootScope();
-  auto params = ops::Placeholder(s.WithOpName("params"), tf_type);
+  auto params = ops::Placeholder(s.WithOpName("params"), tf_type_);
   auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
   auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32);
   auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis);
@@ -5027,7 +5021,7 @@
   {
     // Axis is a tensor, should fail.
     Reset();
-    AddTestTensor("params", {1, 1, 2, 3}, tf_type, {});
+    AddTestTensor("params", {1, 1, 2, 3}, tf_type_, {});
     AddTestTensor("indices", {1, 2}, DT_INT32, {});
     AddTestTensor("axis", {1}, DT_INT32, {});
     RunValidationAndConversion(
@@ -5072,7 +5066,7 @@
                  /*expected_output_shape=*/{2, 1, 1, 3},
                  /*expected_output=*/{4, 5, 6, 1, 2, 3},
                  /*params_is_tensor=*/true,
-                 trt_mode == TrtTestMode::kImplicitBatch
+                 trt_mode_ == TrtTestMode::kImplicitBatch
                      ? Status{error::UNIMPLEMENTED,
                               "TensorRT does not allow manipulation of the"
                               " batch dimension, at my_gather"}
@@ -5085,7 +5079,7 @@
                  /*expected_output_shape=*/{2, 1, 2, 1},
                  /*expected_output=*/{3, 1, 6, 4},
                  /*params_is_tensor=*/true,
-                 trt_mode == TrtTestMode::kImplicitBatch
+                 trt_mode_ == TrtTestMode::kImplicitBatch
                      ? Status{error::UNIMPLEMENTED,
                               "Indices must have a batch size of 1 when params"
                               " is a tensor."}
@@ -5099,7 +5093,7 @@
                  /*expected_output_shape=*/{2, 1, 2},
                  /*expected_output=*/{2, 3, 5, 6},
                  /*params_is_tensor=*/false,
-                 trt_mode == TrtTestMode::kImplicitBatch
+                 trt_mode_ == TrtTestMode::kImplicitBatch
                      ? Status{error::UNIMPLEMENTED,
                               "The input axis must be zero when params is a"
                               " weight."}
@@ -5112,13 +5106,13 @@
                  /*expected_output_shape=*/{2},
                  /*expected_output=*/{2, 4},
                  /*params_is_tensor=*/true,
-                 trt_mode == TrtTestMode::kImplicitBatch  // conversion_status
+                 trt_mode_ == TrtTestMode::kImplicitBatch  // conversion_status
                      ? Status{error::UNIMPLEMENTED,
                               "TensorRT does not allow manipulation of the "
                               "batch dimension, at my_gather"}
                      : Status::OK(),
-                 Status::OK(),                            // runtime_status
-                 trt_mode == TrtTestMode::kImplicitBatch  // add_index_status
+                 Status::OK(),                             // runtime_status
+                 trt_mode_ == TrtTestMode::kImplicitBatch  // add_index_status
                      ? Status{error::INVALID_ARGUMENT,
                               "Batch size doesn't match for tensor indices: "
                               "Provided batch size does not match converter "
@@ -5233,7 +5227,7 @@
     if (p.params_is_tensor) {
       AddTestTensor("params", p.params_shape, params_input);
     } else {
-      AddTestWeights("params", p.params_shape, params_input, tf_type);
+      AddTestWeights("params", p.params_shape, params_input, tf_type_);
     }
     AddTestTensor("indices", p.indices_shape, DT_INT32, p.indices, {},
                   p.add_index_status);
@@ -5273,7 +5267,7 @@
   {
     // Input is weights, should fail.
     Reset();
-    const NodeDef node_def = CreateReduceOp<ops::Sum>(tf_type, false);
+    const NodeDef node_def = CreateReduceOp<ops::Sum>(tf_type_, false);
     AddTestWeights<float>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
     AddTestWeights<int32>("axis", {1}, {1});
     RunValidationAndConversion(
@@ -5283,7 +5277,7 @@
   {
     // Axis is weights, should fail.
     Reset();
-    const NodeDef node_def = CreateReduceOp<ops::Sum>(tf_type, false);
+    const NodeDef node_def = CreateReduceOp<ops::Sum>(tf_type_, false);
     AddTestTensor("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
     AddTestTensor("axis", {1}, DT_INT32, {1});
     RunValidationAndConversion(
@@ -5343,7 +5337,7 @@
       for (auto p : params) {
         SCOPED_TRACE(StrCat(op.name, keep_dims ? "keep_dims" : ""));
         Reset();
-        NodeDef node_def = op.get_node(tf_type, keep_dims);
+        NodeDef node_def = op.get_node(tf_type_, keep_dims);
 
         AddTestTensor("input", p.input_dims, p.input_values);
         AddTestWeights<int32>("axis", {static_cast<int>(p.axis.size())},
@@ -5363,7 +5357,7 @@
             int ax_positive = ax >= 0 ? ax : ax + rank;
             // Zero marks elements that we will remove later.
             expected_output_dims[ax_positive] = keep_dims ? 1 : 0;
-            if (trt_mode == TrtTestMode::kImplicitBatch &&
+            if (trt_mode_ == TrtTestMode::kImplicitBatch &&
                 (ax == 0 || ax == -rank)) {
               p.conversion_status = errors::Unimplemented(
                   "TensorRT does not allow manipulation of the batch "
@@ -5399,7 +5393,7 @@
   {
     // Input is weights, should fail.
     Reset();
-    const NodeDef node_def = CreateUnaryOp<ops::Neg>(tf_type);
+    const NodeDef node_def = CreateUnaryOp<ops::Neg>(tf_type_);
     AddTestWeights<float>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
     RunValidationAndConversion(
         node_def, error::UNIMPLEMENTED,
@@ -5455,7 +5449,7 @@
     if (!op_map.count(op_name)) {
       FAIL() << "Unary op test map does not contain op " << op_name;
     }
-    NodeDef node_def = op_map[op_name].first(tf_type);
+    NodeDef node_def = op_map[op_name].first(tf_type_);
 
     // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for
     // now. Need to find a better way to express input and output types.
@@ -5463,7 +5457,7 @@
     // TODO(tfeher): improve tests by defining an expected output data type and
     // check that. Currently only the shape and values of the output are
     // checked.
-    DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type;
+    DataType input_tf_type = op_name == "Cast" ? DT_HALF : tf_type_;
 
     std::vector<float> input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f};
     AddTestTensor("input", p.input_dims, input_tf_type, input_values);
@@ -6030,7 +6024,7 @@
        /*axis=*/1,
        /*expected_output_dims=*/{1, 2, 2, 3},
        /*expected_output=*/InitTestVector<float>(12),
-       trt_mode == TrtTestMode::kImplicitBatch
+       trt_mode_ == TrtTestMode::kImplicitBatch
            ? Status{error::UNIMPLEMENTED,
                     "The input \"values_1\" for Pack must be a tensor, at "
                     "my_pack"}
@@ -6056,7 +6050,7 @@
        /*axis=*/-4,
        /*expected_output_dims=*/{2, 1, 2, 3},
        /*expected_output=*/InitTestVector<float>(12),
-       trt_mode == TrtTestMode::kImplicitBatch
+       trt_mode_ == TrtTestMode::kImplicitBatch
            ? Status{error::UNIMPLEMENTED,
                     "TensorRT does not allow manipulation of the batch "
                     "dimension, at my_pack"}
@@ -6116,7 +6110,7 @@
       },
   };
   // Inputs have inconsistent shapes, should fail.
-  if (trt_mode != TrtTestMode::kDynamicShape) {
+  if (trt_mode_ != TrtTestMode::kDynamicShape) {
     params.push_back(TestParams{
         /*input_shapes=*/{{1, 2, 3}, {1, 3, 2}},
         /*partial_input_shapes=*/{{}, {}},
@@ -6136,7 +6130,7 @@
     // TODO(tfeher) Add dynamic shapes test once TRT handles shape error
     // decently
   }
-  if (trt_mode == TrtTestMode::kDynamicShape) {
+  if (trt_mode_ == TrtTestMode::kDynamicShape) {
     // Test with mixed dynamic / static shape input tensors
     params.push_back(
         TestParams{/*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
@@ -6152,14 +6146,14 @@
     const int num_inputs = p.input_shapes.size();
     EXPECT_EQ(num_inputs, p.input_values.size());
 
-    NodeDef node_def = GetPackNodeDef(tf_type, num_inputs, p.axis);
+    NodeDef node_def = GetPackNodeDef(tf_type_, num_inputs, p.axis);
     // Create inputs.
     for (int j = 0; j < num_inputs; ++j) {
       if (j == 1 && p.input_1_is_weight) {
         AddTestWeights(StrCat("values_", j), p.input_shapes[j],
-                       p.input_values[j], tf_type);
+                       p.input_values[j], tf_type_);
       } else {
-        AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type,
+        AddTestTensor(StrCat("values_", j), p.input_shapes[j], tf_type_,
                       p.input_values[j], p.partial_input_shapes[j]);
       }
     }
@@ -6687,7 +6681,7 @@
   {
     // Input is a weight, should fail.
     Reset();
-    NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type);
+    NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_);
     AddTestWeights<float>("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
     AddTestTensor("y", {1, 1, 2, 3});
     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
@@ -6714,7 +6708,7 @@
        /*value_y=*/std::vector<float>(7 * 5, 0),
        /*expected_output_dims=*/{1, 1, 2, 3},
        /*expected_output=*/common_input,
-       trt_mode == TrtTestMode::kDynamicShape
+       trt_mode_ == TrtTestMode::kDynamicShape
            ? Status::OK()
            : errors::InvalidArgument("Infeasible broadcast scheme"),
        errors::Internal(
@@ -6740,7 +6734,7 @@
 
   for (auto p : params) {
     Reset();
-    NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type);
+    NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type_);
     AddTestTensor("x", p.dims_x, p.value_x);
     AddTestTensor("y", p.dims_y, p.value_y);
     TestOpConverter("my_squared_diff", node_def, p.expected_output_dims,
@@ -6776,7 +6770,7 @@
 void TestConvertResize(OpConverterTest* test) {
   typedef typename EnumToDataType<dtype>::Type CType;
 
-  std::vector<ResizeTestParams<CType>> params{
+  std::vector<ResizeTestParams<CType>> params {
 // TODO(b/162442839): Enable the test parameters for TRT 7.1.3.x.
 #if !IS_TRT_VERSION_GE(7, 1, 3, 0)
     {
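The trt_mode_ and tf_type_ renames in the hunks above follow the Google C++ convention that test-fixture data members carry a trailing underscore. A minimal sketch of the assumed fixture shape (illustrative class name; the real fixture in this file carries more state and parameters than shown):

    #include <tuple>
    #include "gtest/gtest.h"

    class OpConverterFixtureSketch
        : public ::testing::TestWithParam<std::tuple<TrtTestMode, DataType>> {
     protected:
      // Data members take a trailing underscore, unlike the old free-standing
      // trt_mode / tf_type names the hunks above replace.
      TrtTestMode trt_mode_ = std::get<0>(GetParam());
      DataType tf_type_ = std::get<1>(GetParam());
    };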
diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc
index a699600..1fc0d13 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc
@@ -241,36 +241,6 @@
 
 #endif
 
-string GetLinkedTensorRTVersion() {
-  int major, minor, patch;
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
-  major = NV_TENSORRT_MAJOR;
-  minor = NV_TENSORRT_MINOR;
-  patch = NV_TENSORRT_PATCH;
-#else
-  major = 0;
-  minor = 0;
-  patch = 0;
-#endif
-  return absl::StrCat(major, ".", minor, ".", patch);
-}
-
-string GetLoadedTensorRTVersion() {
-  int major, minor, patch;
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
-  int ver = getInferLibVersion();
-  major = ver / 1000;
-  ver = ver - major * 1000;
-  minor = ver / 100;
-  patch = ver - minor * 100;
-#else
-  major = 0;
-  minor = 0;
-  patch = 0;
-#endif
-  return absl::StrCat(major, ".", minor, ".", patch);
-}
-
 absl::string_view GetDeviceName(const Node* node) {
   if (node->has_assigned_device_name()) {
     return node->assigned_device_name();
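The two deleted helpers above now live behind tensorflow/compiler/tf2tensorrt/common/utils.h and return a tuple directly (see the py_utils_wrapper.cc hunk below). For reference, a sketch of the packed-version decoding they perform, with 7134 as a hypothetical getInferLibVersion() value:

    // getInferLibVersion() packs the version as major*1000 + minor*100 + patch.
    int ver = 7134;                  // hypothetical runtime value
    int major = ver / 1000;          // 7
    int minor = (ver % 1000) / 100;  // 1
    int patch = ver % 100;           // 34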
diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h
index a0505c3..7570dff 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/utils.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h
@@ -117,14 +117,6 @@
 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
 
-// Returns a string that includes compile time TensorRT library version
-// information {Maj, Min, Patch}.
-string GetLinkedTensorRTVersion();
-
-// Returns a string that includes runtime time TensorRT library version
-// information {Maj, Min, Patch}.
-string GetLoadedTensorRTVersion();
-
 // Returns true if an engine built for cached_shapes can also run actual_shapes.
 bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                          const std::vector<TensorShape>& cached_shapes);
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 58d1c61..5b2ae82 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -800,6 +800,9 @@
 
     TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
     infer->setGpuAllocator(allocator);
+    // Need to initialize plugins in order to deserialize engines that contain
+    // plugins.
+    MaybeInitializeTrtPlugins(&logger);
     TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
         infer->deserializeCudaEngine(serialized_segment_.c_str(),
                                      serialized_segment_.size(), nullptr));
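The ordering here is the point of the fix: plugin creators must be registered before deserializeCudaEngine runs, or engines that reference plugins fail to load. Condensed from the hunk above (error handling omitted):

    TrtUniquePtrType<nvinfer1::IRuntime> infer(
        nvinfer1::createInferRuntime(logger));
    infer->setGpuAllocator(allocator);
    MaybeInitializeTrtPlugins(&logger);  // must precede engine deserialization
    TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
        infer->deserializeCudaEngine(serialized_segment_.c_str(),
                                     serialized_segment_.size(), nullptr));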
diff --git a/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc b/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc
deleted file mode 100644
index 141a7d1..0000000
--- a/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc
+++ /dev/null
@@ -1,236 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "absl/strings/str_cat.h"
-#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h"
-#include "tensorflow/core/platform/logging.h"
-
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
-#define EIGEN_USE_GPU  // For definition of Eigen::GpuDevice.
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "tensorflow/core/util/gpu_kernel_helper.h"
-#include "third_party/tensorrt/NvInfer.h"
-
-namespace tensorflow {
-namespace tensorrt {
-using nvinfer1::DataType;
-using nvinfer1::Dims;
-using nvinfer1::IPluginCreator;
-using nvinfer1::IPluginV2;
-using nvinfer1::IPluginV2Ext;
-using nvinfer1::PluginField;
-using nvinfer1::PluginFieldCollection;
-using nvinfer1::PluginFieldType;
-using nvinfer1::PluginFormat;
-
-template <typename SrcT, typename DstT>
-__global__ void Cast(const SrcT* input, int num_elements, DstT* output) {
-  for (int i : CudaGridRangeX(num_elements)) {
-    output[i] = static_cast<DstT>(input[i]);
-  }
-}
-
-template <typename SrcT, typename DstT>
-void RunCast(const SrcT* d_input, int num_elements, DstT* d_output,
-             cudaStream_t stream) {
-  const int threads_per_block = 256;
-  const int blocks_per_grid =
-      (num_elements + threads_per_block - 1) / threads_per_block;
-  TF_CHECK_OK(CudaLaunchKernel(Cast<SrcT, DstT>, threads_per_block,
-                               blocks_per_grid, 0, stream, d_input,
-                               num_elements, d_output));
-}
-
-const char* kPluginName = "TfTrtPluginCast";
-
-class CastPlugin : public TrtPlugin {
- public:
-  CastPlugin(DataType src_type, DataType dst_type)
-      : src_type_(src_type), dst_type_(dst_type) {}
-
-  CastPlugin(const void* serialized_data, size_t length)
-      : TrtPlugin(serialized_data, length) {
-    const char* buffer = static_cast<const char*>(serialized_data);
-    src_type_ = ReadFromBuffer<DataType>(&buffer);
-    dst_type_ = ReadFromBuffer<DataType>(&buffer);
-    src_dims_ = ReadFromBuffer<Dims>(&buffer);
-  }
-
-  CastPlugin(const CastPlugin& rhs)
-      : TrtPlugin(rhs),
-        src_type_(rhs.src_type_),
-        dst_type_(rhs.dst_type_),
-        src_dims_(rhs.src_dims_) {}
-
-  // Methods from IPluginV2Ext.
-
-  DataType getOutputDataType(int index, const DataType* input_types,
-                             int num_inputs) const override {
-    DCHECK_EQ(0, index);
-    DCHECK_EQ(1, num_inputs);
-    return dst_type_;
-  }
-
-  bool isOutputBroadcastAcrossBatch(int output_index,
-                                    const bool* input_is_broadcasted,
-                                    int num_inputs) const override {
-    return false;
-  }
-
-  bool canBroadcastInputAcrossBatch(int input_index) const override {
-    return false;
-  }
-
-  void configurePlugin(const Dims* input_dims, int num_inputs,
-                       const Dims* output_dims, int num_outputs,
-                       const DataType* input_types,
-                       const DataType* output_types,
-                       const bool* input_is_broadcast,
-                       const bool* output_is_broadcast,
-                       PluginFormat float_format, int max_batch_size) override {
-    DCHECK_EQ(1, num_inputs);
-    DCHECK_EQ(1, num_outputs);
-    DCHECK(src_type_ == input_types[0]);
-    DCHECK(dst_type_ == output_types[0]);
-    src_dims_ = input_dims[0];
-  }
-
-  IPluginV2Ext* clone() const override { return new CastPlugin(*this); }
-
-  // Methods from IPluginV2.
-
-  const char* getPluginType() const override { return kPluginName; };
-
-  const char* getPluginVersion() const override { return kTfTrtPluginVersion; };
-
-  int getNbOutputs() const override { return 1; }
-
-  Dims getOutputDimensions(int index, const Dims* inputs,
-                           int num_input_dims) override {
-    DCHECK_EQ(0, index);
-    DCHECK_EQ(1, num_input_dims);
-    return inputs[0];
-  }
-
-  bool supportsFormat(DataType type, PluginFormat format) const override {
-    return type == DataType::kFLOAT || type == DataType::kINT32;
-  }
-
-  size_t getWorkspaceSize(int max_batch_size) const override { return 0; }
-
-  int enqueue(int batch_size, const void* const* inputs, void** outputs, void*,
-              cudaStream_t stream) override {
-    int num_elements = batch_size;
-    for (int i = 0; i < src_dims_.nbDims; i++) {
-      num_elements *= src_dims_.d[i];
-    }
-    const void* input = inputs[0];
-    void* output = outputs[0];
-    DCHECK_NE(static_cast<int>(src_type_), static_cast<int>(dst_type_));
-
-    switch (src_type_) {
-      case DataType::kFLOAT:
-        RunCast(reinterpret_cast<const float*>(input), num_elements,
-                reinterpret_cast<int32*>(output), stream);
-        break;
-      case DataType::kINT32:
-        RunCast(reinterpret_cast<const int32*>(input), num_elements,
-                reinterpret_cast<float*>(output), stream);
-        break;
-      default:
-        return 1;  // Indicates a failure.
-    }
-    return 0;
-  }
-
-  size_t getSerializationSize() const override {
-    return 2 * sizeof(DataType) + sizeof(Dims);
-  }
-
-  void serialize(void* serialized_data) const override {
-    char* buffer = static_cast<char*>(serialized_data);
-    WriteToBuffer(src_type_, &buffer);
-    WriteToBuffer(dst_type_, &buffer);
-    WriteToBuffer(src_dims_, &buffer);
-  }
-
- private:
-  DataType src_type_;
-  DataType dst_type_;
-  Dims src_dims_;
-};
-
-class CastPluginCreator : public IPluginCreator {
- public:
-  CastPluginCreator() {
-    setPluginNamespace(kTfTrtPluginNamespace);
-    plugin_fields_.emplace_back(
-        PluginField("SrcT", nullptr, PluginFieldType::kINT32, 1));
-    plugin_fields_.emplace_back(
-        PluginField("DstT", nullptr, PluginFieldType::kINT32, 1));
-
-    field_collection_.nbFields = plugin_fields_.size();
-    field_collection_.fields = plugin_fields_.data();
-  }
-
-  const char* getPluginName() const override { return kPluginName; }
-
-  const char* getPluginVersion() const override { return kTfTrtPluginVersion; }
-
-  const PluginFieldCollection* getFieldNames() override {
-    return &field_collection_;
-  }
-
-  IPluginV2* createPlugin(
-      const char* name,
-      const PluginFieldCollection* field_collection) override {
-    const PluginField* fields = field_collection->fields;
-    DataType src_type, dst_type;
-    for (int i = 0; i < field_collection->nbFields; ++i) {
-      const char* attr_name = fields[i].name;
-      if (!strcmp(attr_name, "SrcT")) {
-        src_type = *static_cast<const DataType*>(fields[i].data);
-      } else if (!strcmp(attr_name, "DstT")) {
-        dst_type = *static_cast<const DataType*>(fields[i].data);
-      } else {
-        return nullptr;
-      }
-    }
-    return new CastPlugin(src_type, dst_type);
-  }
-
-  IPluginV2* deserializePlugin(const char* name, const void* serial_data,
-                               size_t serial_len) override {
-    return new CastPlugin(serial_data, serial_len);
-  }
-
-  void setPluginNamespace(const char* plugin_namespace) override {
-    namespace_ = plugin_namespace;
-  }
-
-  const char* getPluginNamespace() const override { return namespace_.c_str(); }
-
- private:
-  PluginFieldCollection field_collection_;
-  std::vector<PluginField> plugin_fields_;
-  std::string namespace_;
-};
-
-REGISTER_TFTRT_PLUGIN(CastPluginCreator);
-
-}  // namespace tensorrt
-}  // namespace tensorflow
-
-#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
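A note on the launch math in the deleted RunCast: blocks_per_grid used the standard ceiling division so that every element is covered, with CudaGridRangeX guarding the tail. Worked example at the file's threads_per_block of 256:

    // blocks_per_grid = (num_elements + threads_per_block - 1) / threads_per_block
    // num_elements = 1000: (1000 + 255) / 256 = 4 blocks, 1024 threads total;
    // the last 24 threads see no index in CudaGridRangeX(1000) and do nothing.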
diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
index a8e24aa..3f8a11f 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
@@ -41,31 +41,5 @@
 #endif
 }
 
-void GetLinkedTensorRTVersion(int* major, int* minor, int* patch) {
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
-  *major = NV_TENSORRT_MAJOR;
-  *minor = NV_TENSORRT_MINOR;
-  *patch = NV_TENSORRT_PATCH;
-#else
-  *major = 0;
-  *minor = 0;
-  *patch = 0;
-#endif
-}
-
-void GetLoadedTensorRTVersion(int* major, int* minor, int* patch) {
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
-  int ver = getInferLibVersion();
-  *major = ver / 1000;
-  ver = ver - *major * 1000;
-  *minor = ver / 100;
-  *patch = ver - *minor * 100;
-#else
-  *major = 0;
-  *minor = 0;
-  *patch = 0;
-#endif
-}
-
 }  // namespace tensorrt
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.h b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h
index f52bb6f..9b24eb3 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h
@@ -21,12 +21,6 @@
 
 bool IsGoogleTensorRTEnabled();
 
-// Return compile time TensorRT library version information {Maj, Min, Patch}.
-void GetLinkedTensorRTVersion(int* major, int* minor, int* patch);
-
-// Return runtime time TensorRT library version information {Maj, Min, Patch}.
-void GetLoadedTensorRTVersion(int* major, int* minor, int* patch);
-
 }  // namespace tensorrt
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc
index 03f77c6..52252f1 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc
@@ -16,18 +16,15 @@
 #include <tuple>
 
 #include "pybind11/pybind11.h"
+#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h"
 
 std::tuple<int, int, int> get_linked_tensorrt_version() {
-  int major, minor, patch;
-  tensorflow::tensorrt::GetLinkedTensorRTVersion(&major, &minor, &patch);
-  return std::tuple<int, int, int>{major, minor, patch};
+  return tensorflow::tensorrt::GetLinkedTensorRTVersion();
 }
 
 std::tuple<int, int, int> get_loaded_tensorrt_version() {
-  int major, minor, patch;
-  tensorflow::tensorrt::GetLoadedTensorRTVersion(&major, &minor, &patch);
-  return std::tuple<int, int, int>{major, minor, patch};
+  return tensorflow::tensorrt::GetLoadedTensorRTVersion();
 }
 
 PYBIND11_MODULE(_pywrap_py_utils, m) {
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index ac999d8..e9bcbcc 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -337,7 +337,6 @@
     visibility = [":friends"],
     deps = [
         ":common",
-        ":frontend_attributes_util",
         ":host_compute_metadata_proto_cc",
         ":rearrange_function_argument",
         ":sharding_util",
@@ -353,23 +352,16 @@
         "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
-        "//tensorflow/compiler/jit:xla_cluster_util",
         "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
-        "//tensorflow/compiler/tf2xla/lib:util",
-        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
-        "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
-        "//tensorflow/compiler/xla/client/lib:arithmetic",
-        "//tensorflow/compiler/xla/client/lib:constants",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
@@ -378,11 +370,8 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:stream_executor_no_cuda",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
         "@com_google_absl//absl/types:variant",
     ],
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index d7a8e67..807c061 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/tf2xla/lib/broadcast.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -28,13 +29,26 @@
       : XlaOpKernel(context) {}
 
   void Compile(XlaOpKernelContext* context) override {
-    const TensorShape input_shape = context->InputShape(0);
     TensorShape output_shape;
     OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+    auto output_status_or =
+        BroadcastTo(context->Input(0), output_shape.dim_sizes());
+    OP_REQUIRES_OK(context, output_status_or.status());
+    auto output = output_status_or.ValueOrDie();
+    std::vector<bool> dynamic_dims;
+    OP_REQUIRES_OK(
+        context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
+    for (int64 dim = 0; dim < dynamic_dims.size(); ++dim) {
+      if (dynamic_dims[dim]) {
+        output = xla::SetDimensionSize(
+            output,
+            xla::Reshape(xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}),
+                         {}),
+            dim);
+      }
+    }
 
-    auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
-    OP_REQUIRES_OK(context, output.status());
-    context->SetOutput(0, output.ValueOrDie());
+    context->SetOutput(0, output);
   }
 };
 
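The new loop resolves which output dimensions are dynamic and stamps each one with its runtime size. Isolating the per-dimension pattern (input 1 is the rank-1 shape operand, dim the dimension index; a sketch of the step, not the full kernel):

    // Slice out the requested size for dim and reshape it to a scalar...
    xla::XlaOp size = xla::Reshape(
        xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}), {});
    // ...then record it as the dynamic size of that output dimension.
    output = xla::SetDimensionSize(output, size, dim);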
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index 97359f8..d63b814 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -74,12 +74,44 @@
                                   " vs. ", indices_shape.dim_size(d)));
     }
     xla::XlaBuilder* builder = ctx->builder();
+    // data shape = [indices_shape, segment_shape]
+    // buffer shape = [num_segments, segment_shape]
+    // We now create the buffer shape by reverse engineering the data shape
+    // into indices shape and segment shape.
     TensorShape buffer_shape = data_shape;
     buffer_shape.RemoveDimRange(0, indices_shape.dims());
     buffer_shape.InsertDim(0, num_segments);
+
     auto buffer =
         xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
 
+    // Build dynamic dim sizes for buffer, as well as whether each dimension
+    // size is dynamic or static. We build two parts: the num_segments part
+    // and the segment_shape part.
+    std::vector<xla::XlaOp> buffer_dims;
+    std::vector<bool> buffer_dims_are_dynamic;
+    // Build the "num_segment" part.
+    bool num_segments_is_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic));
+
+    buffer_dims.insert(buffer_dims.begin(), ctx->Input(2));
+    buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(),
+                                   num_segments_is_dynamic);
+    // Build the segment shape part.
+    for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) {
+      buffer_dims.push_back(xla::GetDimensionSize(data, i));
+      buffer_dims_are_dynamic.push_back(
+          ctx->InputXlaShape(0)->is_dynamic_dimension(i));
+    }
+
+    for (int64 i = 0; i < buffer_dims.size(); ++i) {
+      if (buffer_dims_are_dynamic[i]) {
+        // For each dynamic dimension, call set-dimension-size on it.
+        buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i);
+      }
+    }
+
     auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
                            xla::XlaBuilder* builder) { return Combine(a, b); };
 
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 784b790..268317d 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -15,6 +15,9 @@
 
 #include "tensorflow/core/util/strided_slice_op.h"
 
+#include <vector>
+
+#include "absl/algorithm/container.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
@@ -23,6 +26,7 @@
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -33,6 +37,7 @@
 
 namespace tensorflow {
 namespace {
+using errors::InvalidArgument;
 
 class StridedSliceOp : public XlaOpKernel {
  public:
@@ -48,7 +53,6 @@
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
     const TensorShape begin_shape = ctx->InputShape("begin");
-
     OP_REQUIRES(
         ctx, begin_shape.dims() == 1,
         errors::InvalidArgument("'begin' input has to be a rank 1 vector"));
@@ -78,20 +82,24 @@
     TensorShape final_shape;
     PartialTensorShape dummy_processing_shape, partial_final_shape;
     bool dummy = false;
-    OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
-                            begin_is_constant ? &begin_tensor : nullptr,
-                            end_is_constant ? &end_tensor : nullptr,
-                            strides_tensor, input_shape, begin_mask_, end_mask_,
-                            ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
-                            &dummy_processing_shape, &partial_final_shape,
-                            &dummy, &dummy, &dummy, &begin, &end, &strides));
+    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
+    absl::InlinedVector<int64, 4> output_to_processing_mapping;
+    OP_REQUIRES_OK(
+        ctx,
+        ValidateStridedSliceOp(
+            begin_is_constant ? &begin_tensor : nullptr,
+            end_is_constant ? &end_tensor : nullptr, strides_tensor,
+            input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
+            shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
+            &dummy, &dummy, &dummy, &begin, &end, &strides,
+            &output_to_sparse_mapping, &output_to_processing_mapping));
 
-    OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape),
-                errors::InvalidArgument(
-                    "XLA can't deduce compile time constant output "
-                    "shape for strided slice: ",
-                    partial_final_shape.DebugString(),
-                    ", output shape must be a compile-time constant"));
+    OP_REQUIRES(
+        ctx, partial_final_shape.AsTensorShape(&final_shape),
+        InvalidArgument("XLA can't deduce compile time constant output "
+                        "shape for strided slice: ",
+                        partial_final_shape.DebugString(),
+                        ", output shape must be a compile-time constant"));
 
     xla::XlaOp slice = ctx->Input(0);
     if (begin_is_constant && end_is_constant) {
@@ -119,69 +127,84 @@
       auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
       OP_REQUIRES_OK(ctx, operand_shape_or.status());
       xla::Shape xla_shape = operand_shape_or.ValueOrDie();
-      if (xla_shape.is_static()) {
-        // Static output shape, return a static slice.
-        slice = xla::Reshape(slice, final_shape.dim_sizes());
+      std::vector<bool> begins_are_dynamic;
+      OP_REQUIRES_OK(
+          ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
+      std::vector<bool> ends_are_dynamic;
+      OP_REQUIRES_OK(
+          ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
+      bool begins_are_static = absl::c_all_of(
+          begins_are_dynamic, [](bool dynamic) { return !dynamic; });
+      OP_REQUIRES(ctx, begins_are_static,
+                  errors::InvalidArgument(
+                      "XLA can't use dynamic begin values for slice."));
+      bool ends_are_static = absl::c_all_of(
+          ends_are_dynamic, [](bool dynamic) { return !dynamic; });
+      // Static output shape, return a static slice.
+      slice = xla::Reshape(slice, final_shape.dim_sizes());
+      if (xla_shape.is_static() && ends_are_static) {
         ctx->SetOutput(0, slice);
         return;
       }
-      auto input_dim_sizes = input_shape.dim_sizes();
 
-      for (int64 i = 0; i < xla_shape.rank(); ++i) {
-        if (xla_shape.is_dynamic_dimension(i)) {
-          input_dim_sizes[i] = -1;
+      for (int64 i = 0; i < final_shape.dims(); ++i) {
+        int64 input_index = output_to_processing_mapping[i];
+        if (input_index == -1) {
+          continue;
         }
-      }
-      PartialTensorShape input_partial_shape(input_dim_sizes);
-      partial_final_shape.Clear();
-      end.clear();
-      strides.clear();
-      begin.clear();
-      // Run shape inferenference again with partial shape.
-      OP_REQUIRES_OK(ctx, ValidateStridedSliceOp(
-                              &begin_tensor, &end_tensor, strides_tensor,
-                              input_partial_shape, begin_mask_, end_mask_,
-                              ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
-                              &dummy_processing_shape, &partial_final_shape,
-                              &dummy, &dummy, &dummy, &begin, &end, &strides));
-      if (partial_final_shape.AsTensorShape(&final_shape)) {
-        // Static output shape, return a static slice.
-        slice = xla::Reshape(slice, final_shape.dim_sizes());
-        ctx->SetOutput(0, slice);
-        return;
-      }
+        bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
 
-      // We consider slicing a dynamic tensor t with negative indices as a
-      // dynamic sized slice. E.g., t[: -n], the result length is shape(t) - n
-      for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
-        bool dynamic_dim = partial_final_shape.dim_size(i) - 1;
-        bool backward_slice = end[i] < 0;
-        if (dynamic_dim && backward_slice) {
+        int64 sparse_index = output_to_sparse_mapping[i];
+        bool end_is_dynamic =
+            sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
+        bool backward_slice = sparse_index == -1
+                                  ? false
+                                  : end_literal.Get<int32>({sparse_index}) < 0;
+        if ((input_is_dynamic && backward_slice) || end_is_dynamic) {
           OP_REQUIRES(
-              ctx, strides[i] == 1,
+              ctx, strides[input_index] == 1,
               errors::InvalidArgument("XLA has not implemented dynamic "
                                       "sized slice with non-trival stride yet. "
                                       "Please file a bug against XLA"));
-
-          OP_REQUIRES(ctx, begin[i] >= 0,
-                      errors::InvalidArgument(
-                          "XLA has not implemented dynamic "
-                          "sized slice with negative begin index %lld. "
-                          "Please file a bug against XLA",
-                          begin[i]));
           // If there is a dynamic dimension, properly set dimension size of
           // the result.
-          auto operand_size = xla::GetDimensionSize(ctx->Input(0), i);
-
-          operand_size = xla::Add(
-              operand_size, xla::ConstantR0<int32>(ctx->builder(), end[i]));
+          auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index);
+          if (backward_slice) {
+            // We consider slicing a dynamic tensor t with negative indices as
+            // a dynamic sized slice. E.g., t[: -n], the result length is
+            // shape(t) - n.
+            OP_REQUIRES(ctx, !end_is_dynamic,
+                        errors::InvalidArgument(
+                            "XLA has not implemented dynamic "
+                            "sized slice with dynamic negative index %lld. "));
+            operand_size = xla::Add(
+                operand_size,
+                xla::ConstantR0<int32>(ctx->builder(),
+                                       end_literal.Get<int32>({sparse_index})));
+          } else {
+            // The end of slice with dynamic slice size is the min of operand
+            // shape and slice size. E.g., t[:end_size], result size is
+            // min(shape(t), end_size).
+            xla::XlaOp end_size;
+            if (end_is_dynamic) {
+              end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index},
+                                                 {sparse_index + 1}, {1}),
+                                      {});
+            } else {
+              end_size =
+                  xla::ConstantR0<int32>(ctx->builder(), end[input_index]);
+            }
+            operand_size = xla::Min(operand_size, end_size);
+          }
           slice = xla::SetDimensionSize(
               slice,
-              xla::Sub(operand_size,
-                       xla::ConstantR0<int32>(ctx->builder(), begin[i])),
+              xla::Sub(operand_size, xla::ConstantR0<int32>(
+                                         ctx->builder(), begin[input_index])),
               i);
         }
       }
+      ctx->SetOutput(0, slice);
+      return;
     } else {
       // When output shape is fully defined, it must be a size one slice:
       //
@@ -239,9 +262,9 @@
 
       std::vector<int64> output_shape_dim_sizes;
       slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
+      slice = xla::Reshape(slice, final_shape.dim_sizes());
+      ctx->SetOutput(0, slice);
     }
-    slice = xla::Reshape(slice, final_shape.dim_sizes());
-    ctx->SetOutput(0, slice);
   }
 
  private:
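The two dynamic cases above reduce to simple size arithmetic. A worked example with a hypothetical dynamic operand size d = 10 along the sliced dimension (begin = 0, stride = 1):

    // Backward slice t[:-n]: result size = d + end = d - n.
    //   d = 10, end = -3   ->  10 - 3 = 7
    // Forward slice with dynamic end t[:e]: result size = min(d, e).
    //   d = 10, e = 12     ->  min(10, 12) = 10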
diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
index abaeb30..db1a692 100644
--- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
@@ -152,6 +152,7 @@
 
   RegisterDialects();
   mlir::MLIRContext context;
+  context.loadAllGloballyRegisteredDialects();
   TF_ASSIGN_OR_RETURN(
       mlir::OwningModuleRef module,
       ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context));
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 242a2b0..3cf9df6 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -137,7 +137,6 @@
       const auto& it = node.attr().find("allowed_devices");
       if (it != node.attr().end()) {
         if (!it->second.list().s().empty()) {
-          // TODO(b/149512838): Support non-empty allowed devices.
           return errors::InvalidArgument(
               "VarHandleOp with non-empty allowed devices is not supported.");
         }
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index df36311..4d8f6f9 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -36,6 +36,7 @@
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/common_runtime/device.h"
@@ -1142,7 +1143,11 @@
       return errors::InvalidArgument(absl::StrCat(
           "Detected unsupported operations when trying to compile graph ", name,
           " on ", device_type.type_string(), ": ", node->def().op(), " (",
-          s.error_message(), ")", FormatNodeForError(*node)));
+          s.error_message(), ")", FormatNodeForError(*node),
+          "One approach is to outside compile the unsupported ops to run on "
+          "CPUs by enabling soft placement "
+          "`tf.config.set_soft_device_placement(True)`."
+          " This has a potential performance penalty."));
     }
     return Status::OK();
   };
@@ -1357,8 +1362,15 @@
     const string& key, absl::Span<const DataType> types,
     absl::Span<const TensorShape> shapes) {
   if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
-    return errors::InvalidArgument(
-        "Duplicate calls to SetDeviceToHostMetadata with key ", key);
+    tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key];
+    tf2xla::HostTransferMetadata new_transfer;
+    SetTransfer(key, types, shapes, &new_transfer);
+    if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
+      return Status::OK();
+    } else {
+      return errors::InvalidArgument(
+          "Duplicate calls to SetDeviceToHostMetadata with key ", key);
+    }
   }
   tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
   SetTransfer(key, types, shapes, &transfer);
@@ -1384,8 +1396,15 @@
     const string& key, absl::Span<const DataType> types,
     absl::Span<const TensorShape> shapes) {
   if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
-    return errors::InvalidArgument(
-        "Duplicate calls to SetHostToDeviceMetadata with key ", key);
+    tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key];
+    tf2xla::HostTransferMetadata new_transfer;
+    SetTransfer(key, types, shapes, &new_transfer);
+    if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) {
+      return Status::OK();
+    } else {
+      return errors::InvalidArgument(
+          "Duplicate calls to SetHostToDeviceMetadata with key ", key);
+    }
   }
   tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
   SetTransfer(key, types, shapes, &transfer);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index b0d93cd..762700e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -129,8 +129,6 @@
 
     // Resource updates are converted into input / output of xla. The two
     // buffers are aliased with other if this option is true.
-    //
-    // Currently only supports TPU.
     bool alias_resource_update = false;
   };
 
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 5df508d..f348552 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -1897,5 +1897,63 @@
   EXPECT_EQ(alias.entries(0).parameter_number(), 0);
 }
 
+// Tests that passing in an exact duplicate input to SetDeviceToHostMetadata
+// is not an error.
+TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) {
+  XlaCompiler compiler(DefaultOptions());
+
+  const string& key = "comm_key";
+  std::vector<DataType> types{DT_INT32};
+  std::vector<TensorShape> shapes{TensorShape({2})};
+
+  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
+  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
+}
+
+// Tests that passing in a mismatched duplicate input to
+// SetDeviceToHostMetadata is an error.
+TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) {
+  XlaCompiler compiler(DefaultOptions());
+
+  const string& key = "comm_key";
+  std::vector<DataType> types{DT_INT32};
+  std::vector<TensorShape> shapes{TensorShape({2})};
+  std::vector<DataType> types2{DT_FLOAT};
+  std::vector<TensorShape> shapes2{TensorShape({1})};
+
+  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
+  Status status = compiler.SetDeviceToHostMetadata(key, types2, shapes2);
+  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
+}
+
+// Tests that passing in an exact duplicate input to SetHostToDeviceMetadata
+// is not an error.
+TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) {
+  XlaCompiler compiler(DefaultOptions());
+
+  const string& key = "comm_key";
+  std::vector<DataType> types{DT_INT32};
+  std::vector<TensorShape> shapes{TensorShape({2})};
+
+  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
+  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
+}
+
+// Tests that passing in a mismatched duplicate input to
+// SetHostToDeviceMetadata is an error.
+TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) {
+  XlaCompiler compiler(DefaultOptions());
+
+  const string& key = "comm_key";
+  std::vector<DataType> types{DT_INT32};
+  std::vector<TensorShape> shapes{TensorShape({2})};
+  std::vector<DataType> types2{DT_FLOAT};
+  std::vector<TensorShape> shapes2{TensorShape({1})};
+
+  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
+  Status status = compiler.SetHostToDeviceMetadata(key, types2, shapes2);
+  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index a3c7c39..b2a1849 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -199,6 +199,7 @@
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
 )
@@ -304,6 +305,7 @@
         "//tensorflow/compiler/xla/tests:test_macros_header",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
+        "//tensorflow/core/platform:tf32_utils",
     ],
 )
 
@@ -344,6 +346,9 @@
     hdrs = ["sorting.h"],
     deps = [
         ":comparators",
+        ":constants",
+        ":loops",
+        ":slicing",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
@@ -575,6 +580,7 @@
         ":loops",
         ":math",
         ":matrix",
+        ":qr",
         ":slicing",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/client/lib/logdet.cc b/tensorflow/compiler/xla/client/lib/logdet.cc
index 8f37c39..18cd087 100644
--- a/tensorflow/compiler/xla/client/lib/logdet.cc
+++ b/tensorflow/compiler/xla/client/lib/logdet.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/compiler/xla/client/lib/loops.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
@@ -33,13 +34,46 @@
 
 namespace xla {
 
-// let G = root(A) be the Cholesky root of the matrix A
-// log(det(A)) = 2*sum(log(vecdiag(G)))
+// log|det(A)| = sum(log(abs(vecdiag(QR(A).r)))), since R is triangular and Q
+// is orthonormal; the sign is recovered from the diagonal of R and the
+// parity of the Householder reflections.
 XlaOp LogDet(XlaOp a) {
-  XlaOp cholesky = Cholesky(a, /*bool lower=*/true);
+  return a.builder()->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a));
+    // Compute the number of Householder transformations required on 'a' by
+    // determining the number of rows in 'a' that are already triangular. The
+    // determinant of Q is (-1)^(number of Householder transformations).
+    auto rows = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32),
+                     a_shape.rank() - 2);
+    auto cols = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32),
+                     a_shape.rank() - 1);
+    auto in_lower_triangle = Lt(cols, rows);
+    auto is_zero = Eq(a, ScalarLike(a, 0));
+    auto num_zeros_in_triangle_per_row = Einsum(
+        ConvertElementType(And(in_lower_triangle, is_zero), S32), "...a->...");
+    TF_ASSIGN_OR_RETURN(auto row_shape,
+                        a.builder()->GetShape(num_zeros_in_triangle_per_row));
+    rows = Iota(a.builder(), row_shape, row_shape.rank() - 1);
+    auto num_triangle_rows =
+        Einsum(ConvertElementType(Eq(rows, num_zeros_in_triangle_per_row), S32),
+               "...a->...");
+    auto num_rows =
+        ScalarLike(num_triangle_rows, a_shape.dimensions(a_shape.rank() - 2));
 
-  return ScalarLike(a, 2) *
-         Einsum(Log(cholesky), "...aa->...", xla::PrecisionConfig::HIGHEST);
+    TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, true));
+    // Get the sign and the log of the determinant based on the values along
+    // the diagonal of R.
+    auto log_abs_det = Einsum(Log(Abs(qr.r)), "...aa->...");
+    auto sign_diag = Reduce(
+        Sign(Einsum(qr.r, "...aa->...a")),
+        One(a.builder(), a_shape.element_type()),
+        CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()),
+        {a_shape.rank() - 2});
+    return sign_diag * log_abs_det *
+           Select(ConvertElementType(Rem(num_rows - num_triangle_rows,
+                                         ScalarLike(num_triangle_rows, 2)),
+                                     PRED),
+                  ScalarLike(sign_diag, -1.0), ScalarLike(sign_diag, 1.0));
+  });
 }
 
 }  // namespace xla
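For reference, the identity behind the new LogDet, with k the number of Householder reflections the QR decomposition applies (rows that are already triangular need none, hence the num_rows - num_triangle_rows count above):

    A = QR, \quad \det(Q) = (-1)^k,
    \log\lvert\det(A)\rvert = \sum_i \log\lvert r_{ii}\rvert,
    \operatorname{sign}(\det A) = (-1)^k \prod_i \operatorname{sign}(r_{ii})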
diff --git a/tensorflow/compiler/xla/client/lib/logdet_test.cc b/tensorflow/compiler/xla/client/lib/logdet_test.cc
index 54af41f..319d819ed 100644
--- a/tensorflow/compiler/xla/client/lib/logdet_test.cc
+++ b/tensorflow/compiler/xla/client/lib/logdet_test.cc
@@ -51,6 +51,26 @@
                              xla::ErrorSpec(1e-4));
 }
 
+XLA_TEST_F(LogDetTest, SimpleTriangle) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::Array2D<float> a_vals({
+      {4, 6, 8, 10},
+      {4, -39, 62, 73},
+      {0, 0, -146, 166},
+      {4, 6, 8, 320},
+  });
+
+  float expected = -15.9131355f;
+
+  xla::XlaOp a;
+  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
+  xla::LogDet(a);
+
+  ComputeAndCompareR0<float>(&builder, expected, {a_data.get()},
+                             xla::ErrorSpec(1e-4));
+}
+
 XLA_TEST_F(LogDetTest, SimpleBatched) {
   xla::XlaBuilder builder(TestName());
 
diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc
index b7721f2..dbb7360 100644
--- a/tensorflow/compiler/xla/client/lib/matrix.cc
+++ b/tensorflow/compiler/xla/client/lib/matrix.cc
@@ -30,6 +30,7 @@
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
@@ -235,85 +236,93 @@
 XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
 
 namespace {
-std::vector<int64> EinsumDiagonalLabels(absl::Span<const int64> config) {
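+// For a config with repeated labels, returns the distinct labels (in order
+// of first occurrence), the dimensions of the repeated occurrences (to be
+// reduced away), and the dimensions of the first occurrences (kept for
+// broadcasting). Returns nullopt when no label repeats.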
+absl::optional<std::array<std::vector<int64>, 3>> EinsumDiagonalLabels(
+    absl::Span<const int64> config) {
   std::vector<int64> unique_labels;
+  std::vector<int64> reduce_dims;
+  std::vector<int64> broadcast_dims;
   for (auto label = config.begin(); label != config.end(); ++label) {
     auto first_label = absl::c_find(config, *label);
+    auto dim = label - config.begin();
     if (first_label == label) {
       unique_labels.push_back(*label);
+      broadcast_dims.push_back(dim);
+    } else {
+      reduce_dims.push_back(dim);
     }
   }
   if (unique_labels.size() == config.size()) {
-    unique_labels.clear();
+    return absl::nullopt;
   }
-  return unique_labels;
+  return {{unique_labels, reduce_dims, broadcast_dims}};
 }
-}  // namespace
 
-xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
+// Masks a tensor so that only the diagonal entries of repeated indices are
+// non-zero. The result can be used to create a diagonal matrix with an
+// identity reduction.
+xla::XlaOp EinsumDiagonalMask(XlaOp x, absl::Span<const int64> config) {
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    if (EinsumDiagonalLabels(config).empty()) {
-      return x;
-    }
     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
     Shape iota_shape = x_shape;
     iota_shape.set_element_type(S32);
     XlaOp mask = ConstantR0(builder, true);
 
-    absl::InlinedVector<int64, 8> reduce_dims;
     for (auto label = config.begin(); label != config.end(); ++label) {
       const int64 dim = label - config.begin();
       auto first_label = absl::c_find(config, *label);
-      if (first_label == label) {
-        continue;
+      if (first_label != label) {
+        const int64 first_dim = first_label - config.begin();
+        mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
+                            Iota(builder, iota_shape, dim)));
       }
-      reduce_dims.push_back(dim);
-      const int64 first_dim = first_label - config.begin();
-      mask = And(mask, Eq(Iota(builder, iota_shape, first_dim),
-                          Iota(builder, iota_shape, dim)));
     }
-    auto zero = ScalarLike(x, 0);
-    return Reduce(Select(mask, x, zero), zero,
-                  CreateScalarIdentityWithZeroComputation(
-                      x_shape.element_type(), builder),
-                  reduce_dims);
+    return Select(mask, x, ZerosLike(x));
   });
 }
 
-Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
-                                       absl::Span<const int64> y_config,
-                                       absl::Span<const int64> output_config) {
-  for (auto dim : output_config) {
-    if (absl::c_linear_search(x_config, dim) ||
-        absl::c_linear_search(y_config, dim)) {
-      if (absl::c_count(output_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated output dimension.");
-      }
-      continue;
+xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    auto labels = EinsumDiagonalLabels(config);
+    if (!labels) {
+      return x;
     }
-    return InvalidArgument(
-        "Einsum has output dimension without corresponding input dimension.");
-  }
-  for (auto dim : x_config) {
-    if (absl::c_linear_search(y_config, dim) ||
-        absl::c_linear_search(output_config, dim)) {
-      if (absl::c_count(x_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated lhs dimension.");
-      }
-    }
-  }
-  for (auto dim : y_config) {
-    if (absl::c_linear_search(x_config, dim) ||
-        absl::c_linear_search(output_config, dim)) {
-      if (absl::c_count(y_config, dim) > 1) {
-        return InvalidArgument("Einsum has repeated rhs dimension.");
-      }
-    }
-  }
-  return Status::OK();
+    auto zero = ScalarLike(x, 0);
+    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+    return Reduce(EinsumDiagonalMask(x, config), zero,
+                  CreateScalarIdentityWithZeroComputation(
+                      x_shape.element_type(), builder),
+                  labels->at(1));
+  });
 }
 
+xla::XlaOp EinsumInverseDiagonal(XlaOp x, absl::Span<const int64> config) {
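+  // Inverse of EinsumDiagonal: broadcast the reduced diagonal values back to
+  // the full shape implied by `config`, then zero out all off-diagonal
+  // entries via EinsumDiagonalMask.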
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    auto labels = EinsumDiagonalLabels(config);
+    if (!labels) {
+      return x;
+    }
+    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+    std::vector<int64> broadcast_sizes;
+    int64 x_dim = 0;
+    for (auto label = config.begin(); label != config.end(); ++label) {
+      auto first_label = absl::c_find(config, *label);
+      if (first_label == label) {
+        broadcast_sizes.push_back(x_shape.dimensions(x_dim));
+        ++x_dim;
+      } else {
+        broadcast_sizes.push_back(
+            broadcast_sizes[first_label - config.begin()]);
+      }
+    }
+    x = BroadcastInDim(x, broadcast_sizes, labels->at(2));
+    return EinsumDiagonalMask(x, config);
+  });
+}
+}  // namespace
+
 namespace {
 // Helper method to remove dimensions from a shape and dot dimension numbers
 // used to implement implicit broadcasting.
@@ -347,21 +356,23 @@
   XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     auto x_diagonal_labels = EinsumDiagonalLabels(x_config);
+    if (x_diagonal_labels) {
+      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels->at(0), y,
+                    y_config, output_config, precision);
+    }
     auto y_diagonal_labels = EinsumDiagonalLabels(y_config);
-    if (!x_diagonal_labels.empty() && !y_diagonal_labels.empty()) {
-      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels,
-                    EinsumDiagonal(y, y_config), y_diagonal_labels,
-                    output_config, precision);
-    } else if (!x_diagonal_labels.empty()) {
-      return Einsum(EinsumDiagonal(x, x_config), x_diagonal_labels, y, y_config,
-                    output_config, precision);
-    } else if (!y_diagonal_labels.empty()) {
-      return Einsum(x, x_config, EinsumDiagonal(y, y_config), y_diagonal_labels,
-                    output_config, precision);
+    if (y_diagonal_labels) {
+      return Einsum(x, x_config, EinsumDiagonal(y, y_config),
+                    y_diagonal_labels->at(0), output_config, precision);
+    }
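+    // Repeated labels in the output: run the einsum with the de-duplicated
+    // output labels, then scatter the result back onto the diagonal.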
+    auto output_diagonal_labels = EinsumDiagonalLabels(output_config);
+    if (output_diagonal_labels) {
+      return EinsumInverseDiagonal(
+          Einsum(x, x_config, y, y_config, output_diagonal_labels->at(0),
+                 precision),
+          output_config);
     }
 
-    TF_RETURN_IF_ERROR(
-        ValidateEinsumNumericDimensions(x_config, y_config, output_config));
     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
     TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
     const int64 x_rank = x_config.size();
@@ -372,41 +383,37 @@
     absl::flat_hash_set<int64> output_map;
 
     for (auto d : x_config) {
-      if (!x_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support rhs tracing");
-      }
+      x_map.insert(d);
     }
 
     for (auto d : y_config) {
-      if (!y_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support lhs tracing");
-      }
+      y_map.insert(d);
     }
 
     for (auto d : output_config) {
-      if (!output_map.insert(d).second) {
-        return InvalidArgument("XLA Einsum does not support output tracing");
-      }
+      output_map.insert(d);
     }
 
     DotDimensionNumbers dnums;
-    std::vector<int64> lhs_outer_dims;
     auto is_batch_dim = [&](int64 d) {
       return x_map.contains(d) && y_map.contains(d) && output_map.contains(d);
     };
     auto is_contracting = [&](int64 d) {
       return x_map.contains(d) && y_map.contains(d);
     };
+
     auto rhs_dimension_number = [&](int64 d) {
       return absl::c_find(y_config, d) - y_config.begin();
     };
 
     absl::InlinedVector<int64, 8> rhs_outer_dims;
+    absl::InlinedVector<int64, 8> lhs_outer_dims;
     absl::InlinedVector<int64, 8> rhs_delete_dims;
     absl::InlinedVector<int64, 8> lhs_delete_dims;
     for (int64 i = 0; i < x_rank; ++i) {
       auto dim_name = x_config[i];
       const int64 rhs_dim = rhs_dimension_number(dim_name);
+
       if (is_batch_dim(dim_name)) {
         if (x_shape.dimensions(i) == y_shape.dimensions(rhs_dim)) {
           dnums.add_lhs_batch_dimensions(i);
@@ -442,63 +449,90 @@
     }
 
     absl::c_sort(rhs_outer_dims);
-
     absl::InlinedVector<int64, 8> output_transpose_dims;
-    absl::InlinedVector<int64, 8> output_reduce_dims;
-    auto output_dimension_number = [&](int64 d) {
+
+    auto output_dimension_number = [&](int64 d) -> absl::optional<int64> {
       auto pos = absl::c_find(output_config, d);
       if (pos == output_config.end()) {
-        const int64 dim =
-            output_transpose_dims.size() + output_reduce_dims.size();
-        output_reduce_dims.push_back(dim);
-      } else {
-        output_transpose_dims.push_back(pos - output_config.begin());
+        return absl::nullopt;
       }
+      return pos - output_config.begin();
     };
 
     for (auto d : dnums.lhs_batch_dimensions()) {
-      output_dimension_number(x_config[d]);
+      output_transpose_dims.push_back(*output_dimension_number(x_config[d]));
     }
 
     for (auto d : lhs_outer_dims) {
-      output_dimension_number(x_config[d]);
+      if (auto output_dim = output_dimension_number(x_config[d])) {
+        output_transpose_dims.push_back(*output_dim);
+        continue;
+      }
+      lhs_delete_dims.push_back(d);
     }
 
     for (auto d : rhs_outer_dims) {
-      output_dimension_number(y_config[d]);
+      if (auto output_dim = output_dimension_number(y_config[d])) {
+        output_transpose_dims.push_back(*output_dim);
+        continue;
+      }
+      rhs_delete_dims.push_back(d);
     }
 
+    const int64 transpose_rank = output_transpose_dims.size();
     std::vector<int64> transpose_dims(output_rank);
-    for (int64 i = 0; i < output_rank; ++i) {
+    for (int64 i = 0; i < transpose_rank; ++i) {
       transpose_dims[output_transpose_dims[i]] = i;
     }
 
     // Remove the ones that were broadcast from the x and y shapes, and adjust
     // the dimension numbers that are more minor than those dimensions.
+    absl::c_sort(lhs_delete_dims);
     DeleteDimsFromContainer(lhs_delete_dims, &x_shape,
                             dnums.mutable_lhs_batch_dimensions(),
                             dnums.mutable_lhs_contracting_dimensions());
+
+    absl::c_sort(rhs_delete_dims);
     DeleteDimsFromContainer(rhs_delete_dims, &y_shape,
                             dnums.mutable_rhs_batch_dimensions(),
                             dnums.mutable_rhs_contracting_dimensions());
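+    // Dimensions that appear in one operand but neither in the other operand
+    // nor in the output are summed away before the dot.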
     if (!lhs_delete_dims.empty()) {
-      x = Reshape(x, x_shape.dimensions());
+      x = Reduce(x, ScalarLike(x, 0),
+                 CreateScalarAddComputation(x_shape.element_type(), builder),
+                 lhs_delete_dims);
     }
 
     if (!rhs_delete_dims.empty()) {
-      y = Reshape(y, y_shape.dimensions());
+      y = Reduce(y, ScalarLike(y, 0),
+                 CreateScalarAddComputation(y_shape.element_type(), builder),
+                 rhs_delete_dims);
     }
 
     PrecisionConfig precision_proto;
     precision_proto.add_operand_precision(precision);
     precision_proto.add_operand_precision(precision);
     auto dot = DotGeneral(x, y, dnums, &precision_proto);
-    if (!output_reduce_dims.empty()) {
-      dot = Reduce(dot, ScalarLike(dot, 0),
-                   CreateScalarAddComputation(x_shape.element_type(), builder),
-                   output_reduce_dims);
+    dot = Transpose(dot, transpose_dims);
+    if (transpose_rank == output_rank) {
+      return dot;
     }
-    return Transpose(dot, transpose_dims);
+
+    auto is_output_only = [&](int64 d) {
+      return output_map.contains(d) && !x_map.contains(d) && !y_map.contains(d);
+    };
+
+    int64 dot_dim = 0;
+    std::vector<int64> new_dims;
+    new_dims.reserve(output_rank);
+    TF_ASSIGN_OR_RETURN(Shape dot_shape, builder->GetShape(dot));
+    for (auto d : output_config) {
+      if (is_output_only(d)) {
+        new_dims.push_back(1);
+      } else {
+        // Consume the next dot dimension for labels that come from the dot.
+        new_dims.push_back(dot_shape.dimensions(dot_dim));
+        ++dot_dim;
+      }
+    }
+    return Reshape(dot, new_dims);
   });
 }
 
diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h
index 46f70ed..1a9f72d 100644
--- a/tensorflow/compiler/xla/client/lib/matrix.h
+++ b/tensorflow/compiler/xla/client/lib/matrix.h
@@ -112,14 +112,6 @@
 // Returns an empty string if the einsum string already has an ->.
 std::string NormalizeEinsumString(absl::string_view einsum_config);
 
-// Determine if each dimension label is in at least two inputs.
-//
-// NOTE: This function is meant for testing, there is no need to call it
-// directly.
-Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
-                                       absl::Span<const int64> y_config,
-                                       absl::Span<const int64> output_config);
-
 // Supports two operand einsum notation like "ab,cb->ac".
 xla::XlaOp Einsum(
     xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
@@ -128,9 +120,6 @@
     xla::XlaOp x, absl::string_view einsum_config,
     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
 
-// Handles repeated indices within an operand by taking the tensor diagonal of
-// the input.
-xla::XlaOp EinsumDiagonal(XlaOp x, absl::Span<const int64> config);
 
 // Same as above but supporting numeric labels on dimensions. So "ab,cb->ac"
 // becomes:
diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc
index ebbf39e..628447c2 100644
--- a/tensorflow/compiler/xla/client/lib/matrix_test.cc
+++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc
@@ -233,12 +233,23 @@
   };
 
   std::vector<std::vector<string>> good_test_cases = {
-      {"ab", "bc", "ac"},           {"Bab", "Bbc", "Bac"},
-      {"ab", "cd", "dcba"},         {"abc", "abd", "cbd"},
-      {"...ab", "...bc", "...ac"},  {"a...bc", "...abd", "cbd..."},
-      {"...ab", "...bc", "ac"},     {"...b", "...bc", "...c"},
-      {"...abz", "...bc", "...ac"}, {"...ab", "...bcz", "...ac"},
-      {"abz", "bc", "ac"},          {"ab", "bcz", "ac"},
+      {"ab", "bc", "ac"},
+      {"Bab", "Bbc", "Bac"},
+      {"ab", "cd", "dcba"},
+      {"abc", "abd", "cbd"},
+      {"...ab", "...bc", "...ac"},
+      {"a...bc", "...abd", "cbd..."},
+      {"...ab", "...bc", "ac"},
+      {"...b", "...bc", "...c"},
+      {"...abz", "...bc", "...ac"},
+      {"...ab", "...bcz", "...ac"},
+      {"abz", "bc", "ac"},
+      {"ab", "bcz", "ac"},
+
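+      // These were previously rejected by ValidateEinsumNumericDimensions,
+      // which has been removed.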
+      {"a", "b", "c"},
+      {"...a", "...b", "...c"},
+      {"abb", "bcc", "ac"},
+      {"ab", "bc", "ad"},
   };
   for (auto test_case : good_test_cases) {
     auto parse_result_or_status =
@@ -249,9 +260,6 @@
     for (int i = 0; i < 3; ++i) {
       EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
     }
-    EXPECT_TRUE(ValidateEinsumNumericDimensions(
-                    parse_result[0], parse_result[1], parse_result[2])
-                    .ok());
   }
 
   std::vector<string> einsum_strings_that_fail_parsing = {
@@ -261,24 +269,6 @@
     auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
     EXPECT_FALSE(parse_result_or_status.status().ok());
   }
-  std::vector<std::vector<string>> einsum_strings_that_fail_numeric_validation =
-      {
-          {"a", "b", "c"},
-          {"...a", "...b", "...c"},
-          {"abb", "bcc", "ac"},
-          {"ab", "bc", "ad"},
-      };
-
-  for (auto test_case : einsum_strings_that_fail_numeric_validation) {
-    auto parse_result_or_status =
-        ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2]),
-                          test_case[0].size(), test_case[1].size());
-    EXPECT_TRUE(parse_result_or_status.status().ok());
-    auto parse_result = parse_result_or_status.ValueOrDie();
-    EXPECT_FALSE(ValidateEinsumNumericDimensions(
-                     parse_result[0], parse_result[1], parse_result[2])
-                     .ok());
-  }
 }
 
 XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc
index a61f243..9752f84 100644
--- a/tensorflow/compiler/xla/client/lib/qr_test.cc
+++ b/tensorflow/compiler/xla/client/lib/qr_test.cc
@@ -27,12 +27,14 @@
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/tf32_utils.h"
 
 namespace {
 
 using QrTest = xla::ClientLibraryTestBase;
 
 XLA_TEST_F(QrTest, Simple) {
+  tensorflow::allow_tf32_execution(false);  // Test fails with tf32 allowed
   xla::XlaBuilder builder(TestName());
 
   xla::Array2D<float> a_vals({
@@ -61,6 +63,7 @@
 }
 
 XLA_TEST_F(QrTest, ZeroDiagonal) {
+  tensorflow::allow_tf32_execution(false);  // Test fails with tf32 allowed
   xla::XlaBuilder builder(TestName());
 
   xla::Array2D<float> a_vals({
@@ -88,6 +91,7 @@
 }
 
 XLA_TEST_F(QrTest, SimpleBatched) {
+  tensorflow::allow_tf32_execution(false);  // Test fails with tf32 allowed
   xla::XlaBuilder builder(TestName());
 
   xla::Array3D<float> a_vals({
diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc
index 750237c..5a7a701 100644
--- a/tensorflow/compiler/xla/client/lib/sorting.cc
+++ b/tensorflow/compiler/xla/client/lib/sorting.cc
@@ -16,6 +16,9 @@
 #include "tensorflow/compiler/xla/client/lib/sorting.h"
 
 #include "tensorflow/compiler/xla/client/lib/comparators.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/loops.h"
+#include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -27,6 +30,19 @@
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
     int last_dim = input_shape.dimensions_size() - 1;
+    int64 last_dim_size = input_shape.dimensions(last_dim);
+    // TODO(b/165839365): tune these constants for better performance.
+    int64 kPerPartitionSize = 8192;        // 2^13
+    int64 kLastDimSizeThreshold = 524288;  // 2^19
+    int64 kMinNumPartitions = 8;
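+    // Only take the partitioned path when k is small relative to the
+    // partition size and the last dimension is long enough to amortize the
+    // extra while-loop overhead.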
+    if ((k > 0) && (k < kPerPartitionSize) && (kPerPartitionSize / k > 2) &&
+        last_dim_size >= kLastDimSizeThreshold) {
+      int64 num_partitions =
+          CeilOfRatio(last_dim_size - k, kPerPartitionSize - k);
+      if (num_partitions >= kMinNumPartitions) {
+        return TopKWithPartitions(input, k, num_partitions);
+      }
+    }
 
     Shape iota_shape =
         ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
@@ -80,30 +96,35 @@
       }
     }
 
-    XlaOp values, indices;
-    for (int64 partition = 0; partition < num_partitions; partition++) {
-      std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
-      std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
-      std::vector<int64> strides(input_shape.dimensions_size(), 1);
-      start_indices[last_dim] = partition * per_partition_size;
-      limit_indices[last_dim] =
-          std::min((partition + 1) * per_partition_size, last_dim_size);
-      // Slice value and indices for this partition..
-      XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
+    auto topk_body_fn =
+        [&](XlaOp partition, absl::Span<const XlaOp> values_and_indices,
+            XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+      auto values = values_and_indices[0];
+      auto indices = values_and_indices[1];
+      auto input = values_and_indices[2];
+      auto iota_s32 = values_and_indices[3];
+
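+      // Partition 0 is sorted before entering the loop; iteration `partition`
+      // handles partition `partition + 1`, hence the offset in `start`.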
+      // Slice value and indices for this partition.
+      XlaOp start = Mul(Add(partition, ConstantR0<int32>(builder, 1)),
+                        ConstantR0<int32>(builder, per_partition_size));
+      XlaOp sliced_input =
+          DynamicSliceInMinorDims(input, {start}, {per_partition_size});
       XlaOp sliced_indices =
-          Slice(iota_s32, start_indices, limit_indices, strides);
+          DynamicSliceInMinorDims(iota_s32, {start}, {per_partition_size});
       // Concat with previous results.
-      if (partition > 0) {
-        sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
-        sliced_indices =
-            ConcatInDim(builder, {indices, sliced_indices}, last_dim);
-      }
+      sliced_input = ConcatInDim(builder, {values, sliced_input}, last_dim);
+      sliced_indices =
+          ConcatInDim(builder, {indices, sliced_indices}, last_dim);
       // Sort this slice
       XlaOp sort_result =
           Sort({sliced_input, sliced_indices},
                CreateScalarGtComputation({input_shape.element_type(), S32},
                                          sliced_indices.builder()),
-               last_dim, /*is_stable=*/true);
+               last_dim, /*is_stable=*/true);
+
+      std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
+      std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+      std::vector<int64> strides(input_shape.dimensions_size(), 1);
       // Slice topk.
       start_indices[last_dim] = 0;
       limit_indices[last_dim] = k;
@@ -111,8 +132,42 @@
                      limit_indices, strides);
       indices = Slice(GetTupleElement(sort_result, 1), start_indices,
                       limit_indices, strides);
-    }
-    return Tuple(builder, {values, indices});
+      return std::vector<XlaOp>{values, indices, input, iota_s32};
+    };
+
+    // Get the values and indices for the first topk so that they can
+    // be passed to the while loop.
+    std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
+    std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+    std::vector<int64> strides(input_shape.dimensions_size(), 1);
+    start_indices[last_dim] = 0;
+    limit_indices[last_dim] = per_partition_size;
+    // Slice value and indices for the first partition.
+    XlaOp sliced_input = Slice(input, start_indices, limit_indices, strides);
+    XlaOp sliced_indices =
+        Slice(iota_s32, start_indices, limit_indices, strides);
+    // Sort this slice
+    XlaOp sort_result =
+        Sort({sliced_input, sliced_indices},
+             CreateScalarGtComputation({input_shape.element_type(), S32},
+                                       sliced_indices.builder()),
+             last_dim, /*is_stable=*/true);
+
+    // Slice topk.
+    start_indices[last_dim] = 0;
+    limit_indices[last_dim] = k;
+    XlaOp values = Slice(GetTupleElement(sort_result, 0), start_indices,
+                         limit_indices, strides);
+    XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
+                          limit_indices, strides);
+
+    // Pass the result of the first TopK to the while loop and do
+    // num_partitions - 1 iterations.
+    TF_ASSIGN_OR_RETURN(auto values_and_indices,
+                        ForEachIndex(num_partitions - 1, S32, topk_body_fn,
+                                     {values, indices, input, iota_s32},
+                                     "topk_with_partition", builder));
+    return Tuple(builder, {values_and_indices[0], values_and_indices[1]});
   });
 }
 
diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc
index e01f6fa..e820d5b 100644
--- a/tensorflow/compiler/xla/client/lib/sorting_test.cc
+++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc
@@ -118,6 +118,19 @@
   ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
 }
 
+XLA_TEST_F(SortingTest, DISABLED_TopKLargeInput) {
+  XlaBuilder builder(TestName());
+  Array<float> input({2, 1000000});
+  input.FillRandom(1.0f, 2.0f);
+  auto x =
+      CreateConstantFromLiteral(LiteralUtil::CreateFromArray(input), &builder);
+  Array2D<float> expected_array(2, 1000);
+  expected_array.Fill(2.0f);
+  xla::GetTupleElement(xla::TopK(x, 1000), 0);
+  ErrorSpec error_spec(10.0f, 10.0f);
+  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec);
+}
+
 XLA_TEST_F(SortingTest, TopK3From8Indices5Partitions) {
   XlaBuilder builder(TestName());
   auto x_rev =
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 33038dd..34d78f9 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -28,6 +28,7 @@
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
@@ -78,16 +79,13 @@
   return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto();
 }
 
-HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape,
-                                              bool pred) {
-  HloInstructionProto const_instr;
+void SetInstructionAsConstant(HloInstructionProto* instr, int64 id,
+                              const Shape& shape, bool pred) {
   Literal literal = LiteralUtil::CreateR0(pred);
   Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie();
-  *const_instr.mutable_shape() = shape.ToProto();
-  *const_instr.mutable_literal() = literal_broadcast.ToProto();
-  *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
-  const_instr.set_id(id);
-  return const_instr;
+  *instr->mutable_shape() = shape.ToProto();
+  *instr->mutable_literal() = literal_broadcast.ToProto();
+  *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
 }
 
 // Converts a HloComputation into ReducerOr with predicate types.
@@ -2971,27 +2969,12 @@
   *program_shape->mutable_result() =
       ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto();
 
-  std::set<int64> seen;
-  struct WorkItem {
-    explicit WorkItem(int64 handle, bool need_rewrite)
-        : handle(handle), need_rewrite(need_rewrite) {}
-    int64 handle;
-    // If need_rewrite is true, the instruction will be copied and rewrite into
-    // a pred instruction indicating if each value is dynamic. If need_rewrite
-    // is false, simply copy the instruction to the output graph.
-    // E.g.,
-    // For select(P, A, B), we need to rewrite A and B into predicates, but
-    // don't need to rewrite P.
-    bool need_rewrite;
-  };
-  std::queue<WorkItem> worklist;
-  worklist.push(WorkItem(root->id(), true));
-  entry.set_root_id(root->id());
   std::vector<HloComputationProto> called_computatons;
-  // Rewritre instruction with id "from" into the new graph.
-  // Returns more work items that need to finish.
-  auto rewrite_instruction =
-      [&](int64 from, bool need_rewrite) -> StatusOr<std::vector<WorkItem>> {
+  // Process an instruction and copy it into the new graph. The new node in
+  // the new graph will have its id set to `id`.
+  auto process_instruction = [&](const HloInstructionProto* instr_proto,
+                                 bool need_rewrite, int64 id,
+                                 absl::Span<int64 const> operand_ids) {
     // Rewrite the instruction with following rules:
     // - Unary ops: Convert into bitcast (identity) with type Pred.
     // - Binary ops: Convert into binary or.
@@ -3004,22 +2987,20 @@
     // - Constant: Convert to constant False.
     // - Other ops: Not supported.
     // Create the instruction for the new handle.
-    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
-                        LookUpInstructionByHandle(from));
-
     TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                         StringToHloOpcode(instr_proto->opcode()));
-    std::vector<WorkItem> operands_todo;
     auto* new_instr = entry.add_instructions();
     *new_instr = *instr_proto;
-    for (auto operand_id : new_instr->operand_ids()) {
-      operands_todo.emplace_back(operand_id, need_rewrite);
+    new_instr->set_id(id);
+    new_instr->mutable_operand_ids()->Clear();
+    for (auto operand_id : operand_ids) {
+      new_instr->mutable_operand_ids()->Add(operand_id);
     }
 
     if (!need_rewrite) {
       *new_instr->mutable_name() =
-          GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-      return operands_todo;
+          GetFullName(instr_proto->opcode(), kNameSeparator, id);
+      return Status::OK();
     }
     *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape());
     Shape new_shape(new_instr->shape());
@@ -3074,10 +3055,8 @@
         *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr);
         break;
       case HloOpcode::kSelect:
-        operands_todo[0].need_rewrite = false;
         break;
       case HloOpcode::kGather:
-        operands_todo[1].need_rewrite = false;
         break;
       case HloOpcode::kReduce: {
         int64 reducer_id = new_instr->called_computation_ids(0);
@@ -3099,39 +3078,101 @@
         TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                             LookUpInstructionByHandle(operand_handle));
 
-        *new_instr = CreateConstantInstruction(
-            from, new_shape,
+        SetInstructionAsConstant(
+            new_instr, id, new_shape,
             operand_proto->shape().is_dynamic_dimension(dimension));
-        operands_todo.clear();
         break;
       }
       case HloOpcode::kConstant:
-        *new_instr = CreateConstantInstruction(from, new_shape, false);
+        SetInstructionAsConstant(new_instr, id, new_shape, false);
         break;
       case HloOpcode::kParameter:
-        *new_instr = CreateConstantInstruction(from, new_shape, true);
+        SetInstructionAsConstant(new_instr, id, new_shape, true);
         break;
       default:
         return InvalidArgument("Dynamic inferencing %s is not supported",
                                instr_proto->DebugString());
     }
     *new_instr->mutable_name() =
-        GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id());
-    return operands_todo;
+        GetFullName(instr_proto->opcode(), kNameSeparator, id);
+    return Status::OK();
   };
 
+  struct WorkItem {
+    explicit WorkItem(int64 handle, bool need_rewrite)
+        : handle(handle), need_rewrite(need_rewrite), visited(false) {}
+    int64 handle;
+    // If need_rewrite is true, the instruction will be copied and rewritten
+    // into a pred instruction indicating whether each value is dynamic. If
+    // need_rewrite is false, simply copy the instruction to the output graph.
+    // E.g., for select(P, A, B), we need to rewrite A and B into predicates,
+    // but don't need to rewrite P.
+    bool need_rewrite;
+    // Used in the DFS to remember the ids of this item's processed operands.
+    std::vector<int64> processed_operands;
+    // Whether this node has been visited before.
+    bool visited;
+  };
+  // Only copy each pair of {handle, need_rewrite} once. Value is the id in the
+  // new graph.
+  absl::flat_hash_map<std::pair<int64, bool>, int64> seen;
+  // Monotonically increasing id to assign to new instructions.
+  int64 global_id = 0;
+  // The result id of the last rewritten item -- return value of last stack
+  // item.
+  int64 stacktop_id = -1;
+  std::vector<WorkItem> worklist;
+  worklist.push_back(WorkItem(root->id(), true));
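+  // Iterative post-order DFS: an item stays on the stack until all of its
+  // operands have been processed, and is then rewritten using the collected
+  // operand ids.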
   while (!worklist.empty()) {
-    WorkItem item = worklist.front();
-    worklist.pop();
-    if (!seen.insert(item.handle).second) {
+    WorkItem& item = worklist.back();
+    auto item_key = std::make_pair(item.handle, item.need_rewrite);
+    auto iter = seen.find(item_key);
+    // Already processed this item. Return previous results.
+    if (iter != seen.end()) {
+      stacktop_id = iter->second;
+      worklist.pop_back();
       continue;
     }
-    TF_ASSIGN_OR_RETURN(auto todos,
-                        rewrite_instruction(item.handle, item.need_rewrite));
-    for (WorkItem& todo : todos) {
-      worklist.push(todo);
+
+    int64 next_operand = item.processed_operands.size();
+    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+                        LookUpInstructionByHandle(item.handle));
+    VLOG(3) << "Processing" << instr_proto->name();
+    if (!item.visited) {
+      item.visited = true;
+    } else {
+      // Record previous processed operand.
+      item.processed_operands.push_back(stacktop_id);
+      next_operand++;
     }
+    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
+                        StringToHloOpcode(instr_proto->opcode()));
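+    // kGetDimensionSize is rewritten into a constant, so its operand does not
+    // need to be visited.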
+    if (next_operand >= instr_proto->operand_ids_size() ||
+        opcode == HloOpcode::kGetDimensionSize) {
+      // No more operands to process, process self.
+      int64 new_id = ++global_id;
+      VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name();
+      TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite,
+                                             new_id, item.processed_operands));
+      stacktop_id = new_id;
+      seen[item_key] = stacktop_id;
+      worklist.pop_back();
+      continue;
+    }
+
+    WorkItem next_item(instr_proto->operand_ids(next_operand), true);
+    if (opcode == HloOpcode::kSelect && next_operand == 0) {
+      next_item.need_rewrite = false;
+    }
+    if (opcode == HloOpcode::kGather && next_operand == 1) {
+      next_item.need_rewrite = false;
+    }
+    // Push next operand into worklist.
+    worklist.push_back(next_item);
   }
+  TF_RET_CHECK(stacktop_id != -1);
+  entry.set_root_id(stacktop_id);
   absl::c_sort(*entry.mutable_instructions(),
                [](const HloInstructionProto& p1,
                   const HloInstructionProto& p2) { return p1.id() < p2.id(); });
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h
index 935f667..3971153 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h
@@ -695,7 +695,7 @@
   int32 launch_id = 0;
   // If non-null, an opaque context passed to an execution that may be used to
   // supply additional arguments to a derived class of PjRtExecutable.
-  std::unique_ptr<ExecuteContext> context;
+  const ExecuteContext* context = nullptr;
 };
 
 // Represents a compiled computation that can be executed given handles to
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 046fadb..e1eb93f 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -264,6 +264,7 @@
         "//tensorflow/core/platform:status",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@pybind11",
     ],
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index 96cf1e6..6f71152 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -26,11 +26,13 @@
 
 #include "tensorflow/compiler/xla/python/jax_jit.h"
 
+#include <exception>
 #include <memory>
 #include <stdexcept>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/inlined_vector.h"
+#include "absl/synchronization/notification.h"
 #include "absl/types/optional.h"
 #include "pybind11/cast.h"
 #include "pybind11/numpy.h"
@@ -52,11 +54,14 @@
 namespace py = pybind11;
 
 // TODO(phawkins): Add support for Tracers.
-// TODO(jblespiau): Add support for donate_argnums.
 // TODO(jblespiau): Use absl Status.
 
 namespace {
 
+thread_local bool disable_jit;
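+// Python-facing accessors for the thread-local flag above; when the flag is
+// set, CompiledFunction::Call falls back to calling the plain Python
+// function.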
+void SetDisableJit(bool disable_jit_) { disable_jit = disable_jit_; }
+bool GetDisableJit() { return disable_jit; }
+
 // Describes the abstract shape and dtype of an argument.
 struct ArgSignature {
   // This is the XLA dtype of the object.
@@ -225,17 +230,23 @@
   // We need py::object to maintain the objects alive.
   std::vector<py::object> out_avals;
   std::vector<py::object> out_lazy_exprs;
+  // Ensures a single thread performs the compilation for a given executable.
+  //
+  // The first thread (holding the GIL) will create the CacheEntry associated
+  // with a signature; if the entry has already been inserted, other threads
+  // will wait for the notification.
+  absl::Notification compilation_complete;
+  absl::optional<std::exception> compilation_error = absl::nullopt;
 };
 
 // A `CompiledFunction` is associated with a `jax.jit(f)` and takes care of the
 // bookkeeping of the different signatures used and the dispatch of calls to
-// the correct underlying `PyExecutable`.
+// the correct underlying `PyExecutable`. This class is thread-safe.
 class CompiledFunction {
  public:
-  CompiledFunction(py::function cache_miss_fun, py::function python_f_jitted,
-                   bool jax_enable_x64, std::vector<int> static_argnums,
-                   std::shared_ptr<xla::PyClient> pyclient,
-                   xla::PjRtDevice* device);
+  CompiledFunction(py::function fun, py::function cache_miss_fun,
+                   py::function python_f_jitted, bool jax_enable_x64,
+                   bool jax_disable_jit, std::vector<int> static_argnums);
   ~CompiledFunction();
 
   // This function will:
@@ -246,10 +257,23 @@
   // (e) reconstruct the `PyTree`.
   py::object Call(py::args args, py::kwargs kwargs);
 
+  // This allows `inspect.signature(cpp_jitted_f)` from Python.
+  py::object __signature__() {
+    static const auto* inspect = new py::module(py::module::import("inspect"));
+    return inspect->attr("signature")(fun_);
+  }
+
  private:
   CacheEntry& GetCacheEntry(const py::args& args, const py::kwargs& kwargs,
-                            const CallSignature& signature);
+                            const CallSignature& signature,
+                            absl::optional<py::tuple> cache_miss_return);
+  CacheEntry& SetAndReturnCacheEntry(
+      const py::args& args, const py::kwargs& kwargs,
+      const CallSignature& signature,
+      absl::optional<py::tuple> cache_miss_return = absl::nullopt);
+  bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_; }
 
+  const py::function fun_;  // The Python function to jit.
   // The Python function in charge of returning a `xla::PyExecutable` from
   // the arguments passed to `jitted_f`.
   const py::function cache_miss_fun_;
@@ -260,6 +284,7 @@
 
   // The value of the Python flag when the object was created.
   const bool jax_enable_x64_;
+  const bool jax_disable_jit_;
 
   // We need to know the static arguments to remove them from the arguments
   // passed to the underlying PyExecutable. In sorted order.
@@ -267,22 +292,39 @@
   // We need a `unique_ptr` here to ensure value pointer stability.
   absl::flat_hash_map<CallSignature, std::unique_ptr<CacheEntry>> executables_;
 
-  const std::shared_ptr<xla::PyClient> pyclient_;
-  xla::PjRtDevice* const default_device_;
+  // Because top-level functions are decorated with `jax.jit` at import time,
+  // the clients are not yet available when `CompiledFunction` is instantiated
+  // from Python (they are created after GoogleInit). They will be set during
+  // the first call to `Call`.
+  std::shared_ptr<xla::PyClient> pyclient_ = nullptr;
+  xla::PjRtDevice* default_device_ = nullptr;
+
+  // IMPORTANT: The GIL is not always held, because we call back to Python,
+  // and Python will release the GIL.
+  // Thus, we protect the critical section that modifies the `executables_`
+  // map, and more generally the compilation, with an `absl::Notification`.
+  // The first thread to reach this point is responsible for creating the
+  // notification for the executable; the others wait until notified.
+  // This is safe because the first thread holds the GIL while initializing
+  // the `Notification`.
+  //
+  // absl::optional<absl::Notification> is not supported, hence the extra
+  // bool flag below.
+  bool first_compilation_started_ = false;
+  absl::Notification first_compilation_complete_;
+  absl::optional<std::exception> first_compilation_error_ = absl::nullopt;
 };
 
-CompiledFunction::CompiledFunction(py::function cache_miss_fun,
+CompiledFunction::CompiledFunction(py::function fun,
+                                   py::function cache_miss_fun,
                                    py::function python_f_jitted,
-                                   bool jax_enable_x64,
-                                   std::vector<int> static_argnums,
-                                   std::shared_ptr<xla::PyClient> pyclient,
-                                   xla::PjRtDevice* device)
-    : cache_miss_fun_(std::move(cache_miss_fun)),
+                                   bool jax_enable_x64, bool jax_disable_jit,
+                                   std::vector<int> static_argnums)
+    : fun_(std::move(fun)),
+      cache_miss_fun_(std::move(cache_miss_fun)),
       python_f_jitted_(std::move(python_f_jitted)),
       jax_enable_x64_(jax_enable_x64),
-      static_argnums_(std::move(static_argnums)),
-      pyclient_(std::move(pyclient)),
-      default_device_(device) {
+      jax_disable_jit_(jax_disable_jit),
+      static_argnums_(std::move(static_argnums)) {
   std::sort(static_argnums_.begin(), static_argnums_.end());
 }
 
@@ -493,8 +535,21 @@
   xla::PjRtDevice* data_device = nullptr;
   for (py::handle arg : arguments.flat_dynamic_args) {
     if (py::isinstance(arg, device_array)) {
-      xla::PyBuffer* buffer =
-          py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
+      xla::PyBuffer* buffer;
+      try {
+        // This can fail, e.g. when device_buffer is a `DeviceConstant`.
+        buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
+      } catch (const py::cast_error& e) {
+        return InvalidArgument(
+            "%s",
+            absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
+                         "`device_buffer` field is of type ",
+                         py::cast<std::string>(
+                             arg.attr("device_buffer").get_type().str()),
+                         " while a `PyBuffer` was expected."
+
+                         ));
+      }
       xla::PjRtDevice* device = buffer->buffer()->device();
       if (data_device && (device != data_device)) {
         return InvalidArgument(
@@ -535,8 +590,6 @@
       py::array numpy_array = py::cast<py::array>(arg);
       // If jax_enable_x64 is not set, we need to coerce 32 bits types.
       // Note that this is calling back to Python!
-      // TODO(jblespiau): We can remove this complexity when we delete
-      // jax_enable_x64 mode.
       if (!jax_enable_x64) {
         const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
         if (to_dtype) {
@@ -577,49 +630,73 @@
 
 }  // namespace
 
-CacheEntry& CompiledFunction::GetCacheEntry(const py::args& args,
-                                            const py::kwargs& kwargs,
-                                            const CallSignature& signature) {
+CacheEntry& CompiledFunction::GetCacheEntry(
+    const py::args& args, const py::kwargs& kwargs,
+    const CallSignature& signature,
+    absl::optional<py::tuple> cache_miss_return) {
   auto found_iterator = executables_.find(signature);
   if (found_iterator != executables_.end()) {  // Cache hit!
+    if (!found_iterator->second->compilation_complete.HasBeenNotified()) {
+      py::gil_scoped_release gil_release;
+      found_iterator->second->compilation_complete.WaitForNotification();
+      if (found_iterator->second->compilation_error) {
+        throw found_iterator->second->compilation_error.value();
+      }
+    }
     return *(found_iterator->second);
   }
-
+  return SetAndReturnCacheEntry(args, kwargs, signature, cache_miss_return);
+}
+CacheEntry& CompiledFunction::SetAndReturnCacheEntry(
+    const py::args& args, const py::kwargs& kwargs,
+    const CallSignature& signature,
+    absl::optional<py::tuple> cache_miss_return) {
   // We need to insert the element.
   auto result = executables_.emplace(signature, std::make_unique<CacheEntry>());
   auto it = result.first;
-
+  CacheEntry& cache_entry = *(it->second.get());
   // CallSignatures in the cache own their keyword argument reference.
   result.first->first.IncRef();
 
   // Cache miss? Call the Python cache miss function.
-  py::tuple executable_and_pytree = cache_miss_fun_(*args, **kwargs);
+  py::tuple executable_and_pytree;
+  if (cache_miss_return) {
+    executable_and_pytree = cache_miss_return.value();
+  } else {
+    try {
+      executable_and_pytree = cache_miss_fun_(*args, **kwargs);
+    } catch (const std::exception& e) {
+      cache_entry.compilation_error = e;
+      cache_entry.compilation_complete.Notify();
+      throw;
+    }
+  }
   if (executable_and_pytree.size() != 4) {
     throw std::runtime_error(
         "AssertionError: The cache miss function should return 4 "
         "arguments.");
   }
-  it->second->executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
+  cache_entry.executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
       std::move(executable_and_pytree[0]));
   int num_devices =
-      it->second->executable->pjrt_executable().local_devices().size();
+      cache_entry.executable->pjrt_executable().local_devices().size();
   if (num_devices != 1) {
     throw std::runtime_error(absl::StrCat(
         "Running on more than a single device is not currently supported."
         "The underlying PjRtExecutable has ",
         num_devices));
   }
-  it->second->device =
-      it->second->executable->pjrt_executable().local_devices()[0];
-  it->second->out_pytree_def = py::cast<PyTreeDef>(executable_and_pytree[1]);
+  cache_entry.device =
+      cache_entry.executable->pjrt_executable().local_devices()[0];
+  cache_entry.out_pytree_def = py::cast<PyTreeDef>(executable_and_pytree[1]);
 
   py::list shaped_arrays =
       py::reinterpret_borrow<py::object>(executable_and_pytree[2]);
   py::list lazy_expressions =
       py::reinterpret_borrow<py::object>(executable_and_pytree[3]);
 
-  it->second->out_avals.reserve(shaped_arrays.size());
-  it->second->out_lazy_exprs.reserve(lazy_expressions.size());
+  cache_entry.out_avals.reserve(shaped_arrays.size());
+  cache_entry.out_lazy_exprs.reserve(lazy_expressions.size());
 
   int num_outputs = shaped_arrays.size();
   for (int i = 0; i < num_outputs; ++i) {
@@ -628,17 +705,53 @@
     py::object lazy_expr =
         py::reinterpret_borrow<py::object>(lazy_expressions[i]);
 
-    it->second->out_avals.push_back(shaped_array);
-    it->second->out_lazy_exprs.push_back(lazy_expr);
+    cache_entry.out_avals.push_back(shaped_array);
+    cache_entry.out_lazy_exprs.push_back(lazy_expr);
   }
 
-  return *(it->second);
+  cache_entry.compilation_complete.Notify();
+  return cache_entry;
 }
 
 py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
+  if (JitIsDisabled()) {
+    return fun_(*args, **kwargs);
+  }
   ParsedArgumentsAsBuffers arguments;
   FlattenArguments(args, kwargs, static_argnums_, arguments);
 
+  // TODO(jblespiau): It would be preferable to have a single location for
+  // locking code.
+  absl::optional<py::tuple> cache_miss_result = absl::nullopt;
+  if (!default_device_) {
+    // TODO(jblespiau): This code will deadlock if a jitted function
+    // recursively calls itself.
+    if (first_compilation_started_) {
+      if (!first_compilation_complete_.HasBeenNotified()) {
+        py::gil_scoped_release gil_release;
+        first_compilation_complete_.WaitForNotification();
+        if (first_compilation_error_) {
+          throw first_compilation_error_.value();
+        }
+      }
+    } else {
+      first_compilation_started_ = true;
+      try {
+        cache_miss_result = cache_miss_fun_(*args, **kwargs);
+      } catch (const std::exception& e) {
+        first_compilation_error_ = e;
+        first_compilation_complete_.Notify();
+        throw;
+      }
+      auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
+          cache_miss_result.value()[0]);
+
+      pyclient_ = executable->client();
+      default_device_ = executable->LocalDevices()[0].contents;
+      first_compilation_complete_.Notify();
+    }
+  }
+
   // The C++ jit does not support Tracer arguments yet. The Python-based jit
   // function will be called if any of the dynamic arguments is unsupported.
   if (!ConvertArgsToBuffers(jax_enable_x64_, *pyclient_, default_device_,
@@ -647,7 +760,8 @@
     return python_f_jitted_(*args, **kwargs);
   }
 
-  CacheEntry& cache_entry = GetCacheEntry(args, kwargs, arguments.signature);
+  CacheEntry& cache_entry =
+      GetCacheEntry(args, kwargs, arguments.signature, cache_miss_result);
 
   std::vector<std::unique_ptr<xla::PyBuffer>> outputs =
       ValueOrThrow(cache_entry.executable->PjRtExecute(arguments.arg_buffers));
@@ -677,19 +791,21 @@
   py::class_<CompiledFunction, std::unique_ptr<CompiledFunction>> cfun(
       jitlib, "CompiledFunction");
   cfun.def("__call__", &CompiledFunction::Call);
+  cfun.def_property_readonly("__signature__", &CompiledFunction::__signature__);
 
-  jitlib.def("jit",
-             [](py::function cache_miss_fun,
-                py::function fallback_on_unsupported_argument,
-                bool jax_enable_x64, std::vector<int> static_argnums,
-                xla::ClientAndPtr<xla::PjRtDevice> client_and_device)
-                 -> std::unique_ptr<CompiledFunction> {
-               return std::make_unique<CompiledFunction>(
-                   std::move(cache_miss_fun),
-                   std::move(fallback_on_unsupported_argument), jax_enable_x64,
-                   std::move(static_argnums), client_and_device.client,
-                   client_and_device.contents);
-             });
+  jitlib.def("set_disable_jit", &SetDisableJit);
+  jitlib.def("get_disable_jit", &GetDisableJit);
+  jitlib.def(
+      "jit",
+      [](py::function fun, py::function cache_miss_fun,
+         py::function fallback_on_unsupported_argument, bool jax_enable_x64,
+         bool jax_disable_jit,
+         std::vector<int> static_argnums) -> std::unique_ptr<CompiledFunction> {
+        return std::make_unique<CompiledFunction>(
+            std::move(fun), std::move(cache_miss_fun),
+            std::move(fallback_on_unsupported_argument), jax_enable_x64,
+            jax_disable_jit, std::move(static_argnums));
+      });
 
   // Only for testing purposes
   jitlib.def("_ScalarToBuffer", [](py::handle scalar, bool jax_enable_x64,
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index b3ba406..0660566 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -171,13 +171,13 @@
 void BuildProfilerSubmodule(py::module* m) {
   py::module profiler =
       m->def_submodule("profiler", "TensorFlow profiler integration");
-  py::class_<tensorflow::ProfilerServer,
-             std::unique_ptr<tensorflow::ProfilerServer>>
+  py::class_<tensorflow::profiler::ProfilerServer,
+             std::unique_ptr<tensorflow::profiler::ProfilerServer>>
       profiler_server_class(profiler, "ProfilerServer");
   profiler.def(
       "start_server",
-      [](int port) -> std::unique_ptr<tensorflow::ProfilerServer> {
-        auto server = absl::make_unique<tensorflow::ProfilerServer>();
+      [](int port) -> std::unique_ptr<tensorflow::profiler::ProfilerServer> {
+        auto server = absl::make_unique<tensorflow::profiler::ProfilerServer>();
         server->StartProfilerServer(port);
         return server;
       },
@@ -740,7 +740,7 @@
       .def(py::init([](const py::bytes& serialized_hlo_module_proto)
                         -> std::unique_ptr<XlaComputation> {
         HloModuleProto proto;
-        proto.ParseFromString(serialized_hlo_module_proto);
+        proto.ParseFromString(std::string(serialized_hlo_module_proto));
         return absl::make_unique<XlaComputation>(proto);
       }))
       .def("get_hlo_module", &GetHloModule)
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index dd16bd3..a1d6959 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2343,6 +2343,7 @@
         ":hlo_dce",
         ":hlo_pass",
         ":hlo_pass_pipeline",
+        ":hlo_verifier",
         ":tuple_simplifier",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index cb1bb19..3d49700 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1300,7 +1300,15 @@
       auto replacement =
           computation_->AddInstruction(concatenate->CloneWithNewOperands(
               concatenate->shape(), new_operands));
-      ReplaceInstructionIfSameShape(concatenate, replacement);
+
+      // Recurse to handle multiple disjoint sequences of inputs. The logic
+      // above merges only one sequential series of inputs; otherwise, the
+      // fixed-point pass driver can end up hitting its iteration threshold.
+      if (ReplaceInstructionIfSameShape(concatenate, replacement)) {
+        return HandleConcatenate(replacement);
+      }
+
       return Status::OK();
     }
   }
@@ -2500,6 +2508,20 @@
   if (ShapeUtil::IsZeroElementArray(operand_shape)) {
     return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
   }
+
+  // Gathering from a scalar operand is simply a broadcast of that scalar.
+  if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
+    HloInstruction* new_operand = gather->mutable_operand(0);
+    if (operand_shape.rank()) {
+      TF_ASSIGN_OR_RETURN(new_operand,
+                          MakeReshapeHlo(ShapeUtil::MakeScalarShape(
+                                             operand_shape.element_type()),
+                                         new_operand));
+    }
+    HloInstruction* new_gather =
+        MakeBroadcastHlo(new_operand, {}, gather->shape());
+    return ReplaceInstruction(gather, new_gather);
+  }
   // If the operand of a gather is very small, it is easier to fuse a
   // sequence of selects.
   const Shape& index_shape = gather->operand(1)->shape();
@@ -3289,6 +3311,9 @@
   // padding with a pad with non-negative padding followed by a slice.
   bool all_zero = true;
   bool has_negative = false;
+  // Collects the dimensions that are actually padded, so that the unchanged
+  // dimensions can be split off below.
+  std::vector<int64> padding_dimensions;
+  int64 dimension_index = 0;
   for (auto& padding_dimension : pad->padding_config().dimensions()) {
     if (padding_dimension.edge_padding_low() < 0 ||
         padding_dimension.edge_padding_high() < 0) {
@@ -3297,12 +3322,93 @@
     if (padding_dimension.edge_padding_low() != 0 ||
         padding_dimension.edge_padding_high() != 0) {
       all_zero = false;
+      padding_dimensions.push_back(dimension_index);
+    } else if (padding_dimension.interior_padding()) {
+      padding_dimensions.push_back(dimension_index);
     }
+    dimension_index++;
   }
 
   if (all_zero) {
-    ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0));
-    return Status::OK();
+    if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) {
+      return Status::OK();
+    }
+  }
+
+  // The context of this optimization can be found at b/163617402.
+  // It captures the case of pad(broadcast(x)), where the broadcast
+  // dimensions of x, i.e. broadcast(x)->dimensions(), are a subset of the
+  // padded dimensions in pad->padding_config(), and the padded dimensions
+  // are in turn a strict subset of broadcast->shape().dimensions(). The
+  // combined op can be rewritten as broadcast2(pad(broadcast1(x))), where
+  // broadcast1 extends x with the dimensions that need to be padded, and
+  // broadcast2 extends the result of the padding to the full dimensions.
+  // TODO(qyi): For future extensions: the first condition, that
+  // broadcast(x)->dimensions() is a subset of the padded dimensions, does
+  // not strictly have to hold, but it makes the shape calculations easier,
+  // so the current implementation requires it. Only the second condition,
+  // between the padded dimensions and the dimensions of the final shape,
+  // has to be enforced for the optimization to make sense. Removing the
+  // first constraint would require re-adjusting the shape calculations
+  // across the implementation.
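+  // For illustration (mirroring the BroadcastAndPadReorder test):
+  //   b = pred[32,1,768] broadcast(pred[] c), dimensions={}
+  //   p = pred[4096,1,768] pad(b, pad_value), padding=0_4064x0_0x0_0
+  // is rewritten to
+  //   b1 = pred[32] broadcast(pred[] c), dimensions={}
+  //   p1 = pred[4096] pad(b1, pad_value), padding=0_4064
+  //   b2 = pred[4096,1,768] broadcast(p1), dimensions={0}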
+  auto pad_dims = padding_dimensions.size();
+  if (pad_dims < dimension_index &&
+      pad->operand(0)->opcode() == HloOpcode::kBroadcast &&
+      pad->operand(0)->user_count() == 1 &&
+      pad->operand(0)->operand(0)->shape().rank() <= pad_dims) {
+    // Check that the broadcast operand dimensions are a subset of
+    // padding_dimensions. If not, skip the optimization.
+    bool opt_is_valid = true;
+    std::vector<int64> broadcast_dimensions;
+    HloBroadcastInstruction* broadcast =
+        static_cast<HloBroadcastInstruction*>(pad->mutable_operand(0));
+    for (auto broadcast_index : broadcast->dimensions()) {
+      bool found = false;
+      for (int i = 0; i < pad_dims; ++i) {
+        if (broadcast_index == padding_dimensions[i]) {
+          broadcast_dimensions.push_back(i);
+          found = true;
+          break;
+        }
+      }
+      if (!found) {
+        opt_is_valid = false;
+        break;
+      }
+    }
+    if (opt_is_valid) {
+      auto pad_shape = pad->shape();
+      auto broadcast_shape = broadcast->shape();
+      auto pad_shape1 = pad_shape;
+      auto broadcast_shape1 = broadcast_shape;
+      PaddingConfig pad_config;
+      for (int i = padding_dimensions.size() - 1; i >= 0; --i) {
+        int64 j = padding_dimensions[i];
+        while (--dimension_index > j) {
+          broadcast_shape1.DeleteDimension(dimension_index);
+          pad_shape1.DeleteDimension(dimension_index);
+        }
+      }
+      while (--dimension_index >= 0) {
+        broadcast_shape1.DeleteDimension(dimension_index);
+        pad_shape1.DeleteDimension(dimension_index);
+      }
+      for (auto dimension_to_pad : padding_dimensions) {
+        auto dimension = pad_config.add_dimensions();
+        *dimension = pad->padding_config().dimensions(dimension_to_pad);
+      }
+      *broadcast->mutable_shape() = broadcast_shape1;
+      *broadcast->mutable_dimensions() = broadcast_dimensions;
+      simplifier_->UpdateLayout(broadcast->mutable_shape());
+      auto pad2 =
+          computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1));
+      *pad2->mutable_padding_config() = pad_config;
+      simplifier_->UpdateLayout(pad2->mutable_shape());
+      auto broadcast2 = computation_->AddInstruction(
+          HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions));
+      return ReplaceInstruction(pad, broadcast2);
+    }
   }
 
   if (has_negative) {
@@ -3337,7 +3443,8 @@
         pad->shape(), nonzero_pad->mutable_shape()));
     simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
 
-    // Second, construct the slice instruction to perform the negative padding.
+    // Second, construct the slice instruction to perform the negative
+    // padding.
     std::vector<int64> start_indices;
     std::vector<int64> end_indices;
     std::vector<int64> strides;
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index c3e9061..c08fbd1 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2319,7 +2319,7 @@
 TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
   auto m = CreateNewVerifiedModule();
   Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99});
-  Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 80});
+  Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 90});
   HloComputation::Builder builder(TestName());
   HloInstruction* param0 = builder.AddInstruction(
       HloInstruction::CreateParameter(0, r2f32, "param0"));
@@ -2366,10 +2366,15 @@
   HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice(
       ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79},
       /*limit_indices=*/{100, 89}, /*strides=*/{1, 1}));
+  // Can merge 'slice7' and 'slice8'.
+  HloInstruction* slice8 = builder.AddInstruction(HloInstruction::CreateSlice(
+      ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 89},
+      /*limit_indices=*/{100, 99}, /*strides=*/{1, 1}));
 
   builder.AddInstruction(HloInstruction::CreateConcatenate(
       concat_shape,
-      {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1));
+      {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8},
+      1));
   auto computation = m->AddEntryComputation(builder.Build());
 
   AlgebraicSimplifier simplifier(default_options_);
@@ -2384,6 +2389,12 @@
       ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(),
                        ShapeUtil::MakeShape(F32, {50, 30})));
   EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);
+
+  // Operand 5 should be the merge of 'slice7' and 'slice8', so its
+  // shape should have dimensions {50, 20}.
+  EXPECT_TRUE(
+      ShapeUtil::Equal(computation->root_instruction()->operand(5)->shape(),
+                       ShapeUtil::MakeShape(F32, {50, 20})));
 }
 
 // Test that a simplification which changes layouts is not performed if layout
@@ -5647,6 +5658,30 @@
     DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
     ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
 
+TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) {
+  const char* hlo_string = R"(
+  HloModule repeat
+
+  ENTRY main {
+    o = f32[1,1] parameter(0)
+    i = s32[100,2] parameter(1)
+    ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1},
+                                  start_index_map={0,1},
+                                  index_vector_dim=1,
+                                  offset_dims={},
+                                  slice_sizes={1,1}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AlgebraicSimplifierOptions options;
+  AlgebraicSimplifier simplifier(options);
+  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
+}
+
 TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
   const char* hlo_string = R"(
 HloModule module
@@ -6931,5 +6966,57 @@
               GmockMatch(m::Add(m::Parameter(0), m::Parameter(1))));
 }
 
+TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorder) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      c1 = pred[] constant(true)
+      b2 = pred[32,1,768]{2,1,0} broadcast(pred[] c1), dimensions={}
+      c3 = pred[] constant(false)
+      ROOT p4 = pred[4096,1,768]{2,1,0} pad(pred[32,1,768]{2,1,0} b2, pred[] c3), padding=0_4064x0_0x0_0
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Broadcast(
+                  m::Pad(m::Broadcast(m::Constant()), m::Constant()))));
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithUse) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      c1 = pred[] constant(true)
+      b2 = pred[1,768,32]{2,1,0} broadcast(pred[] c1), dimensions={}
+      c3 = pred[] constant(false)
+      p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
+      ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::Broadcast(
+                  m::Pad(m::Broadcast(m::Constant()), m::Constant())))));
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) {
+  const char* kModuleStr = R"(
+    HloModule m
+    test {
+      c1 = pred[32] parameter(0)
+      b2 = pred[1,768,32]{2,1,0} broadcast(pred[32] c1), dimensions={2}
+      c3 = pred[] constant(false)
+      p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
+      ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Tuple(m::Broadcast(
+                  m::Pad(m::Broadcast(m::Parameter()), m::Constant())))));
+}
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc
index cdda0ae..42caf20 100644
--- a/tensorflow/compiler/xla/service/conditional_code_motion.cc
+++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc
@@ -34,6 +34,7 @@
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -100,7 +101,7 @@
 // of reuses. This is used as a placeholder only, assuming all
 // instructions can be fused to enable data reuses.
 int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
-  VLOG(1) << "ConditionalCodeMotion: Add reuses carried by instr: "
+  VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
           << op->ToString() << "=>" << user->ToString() << "\n";
   switch (user->opcode()) {
     case HloOpcode::kGetTupleElement:
@@ -114,6 +115,8 @@
     case HloOpcode::kConstant:
     case HloOpcode::kGetTupleElement:
       return 0;
+    case HloOpcode::kConditional:
+      return 10;
     default:
      // Assume fusion will not happen anyway if user count > 1.
       if (op->user_count() > 1) {
@@ -432,7 +435,8 @@
   if (to_move_out.empty()) {
     return false;
   }
-  VLOG(1) << "number of boundaries to move out:" << to_move_out.size() << "\n";
+  VLOG(1) << "Modifying code--number of boundaries to move out:"
+          << to_move_out.size() << "\n";
   HloComputation* conditional_parent = conditional->parent();
  // Save the old users before adding new conditional user instructions.
   std::vector<HloInstruction*> old_conditional_users = conditional->users();
@@ -441,7 +445,7 @@
   absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
   // Insert GetTupleElement before the instructions whose operands might still
   // be within the conditional.
-  VLOG(2) << "before opt:"
+  VLOG(1) << "before opt:"
           << conditional_parent->ToString(HloPrintOptions::Fingerprint())
           << "\n";
   int64 op_index = 0;
@@ -470,16 +474,22 @@
   HloInstruction* old_root =
       conditional->branch_computation(0)->root_instruction();
   for (auto user_instr : old_conditional_users) {
+    VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n";
     CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
     auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr);
     int64 index = tuple_opd->tuple_index();
+    CHECK(old_root->operands().size() > index);
     HloInstruction* old_opd = old_root->operands()[index];
+    CHECK(ContainsKey(hoisted_instructions, old_opd));
     HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0];
     CHECK(old_opd != nullptr);
     CHECK(new_opd != nullptr);
+    VLOG(2) << "Try replace all uses of :" << old_opd->ToString() << "\n";
     TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
     TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
   }
+  VLOG(2) << "Done changing conditional users\n"
+          << conditional_parent->ToString() << "\n";
   // Create tuple element within each branch and set it as root.
   int64 branch_count = conditional->branch_count();
   for (int i = 0; i < branch_count; i++) {
@@ -487,9 +497,8 @@
     std::vector<HloInstruction*> elements;
     for (auto b1 : new_boundaries) {
       HloInstruction* op = b1.operands()[i];
-      VLOG(1) << "branch count=" << i << "\n";
       CHECK(op != nullptr);
-      VLOG(1) << "Adding to root " << i << " with " << op->ToString() << "\n";
+      VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
       elements.push_back(op);
     }
     HloInstruction* tuple =
@@ -507,7 +516,7 @@
       conditional->branch_computation(0)->root_instruction();
   *conditional->mutable_shape() = new_root->shape();
   //
-  VLOG(2) << "done moving instructions out of branches\n"
+  VLOG(1) << "done moving instructions out of branches\n"
           << conditional_parent->ToString(HloPrintOptions::Fingerprint())
           << "\n";
   return true;
@@ -520,48 +529,89 @@
   if (to_move_in.empty()) {
     return false;
   }
-  VLOG(1) << "number of boundaries to move in:" << to_move_in.size() << "\n";
-  HloComputation* conditional_parent = conditional->parent();
-  VLOG(2) << "before opt:"
-          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
+  VLOG(1) << "Modifying code---number of boundaries to move in:"
+          << to_move_in.size() << "\n";
+  VLOG(1) << "before opt:"
+          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
           << "\n";
   // Mapping instructions to be moved to their new representations.
   absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
   int64 to_move_in_size = to_move_in.size();
   int64 branch_count = conditional->branch_count();
-  int64 op_index = conditional->shape().tuple_shapes_size();
-  // Map conditional to its old root, then create a new root instruction in each
-  // branch.
-  Boundary b(Boundary::Position::kInsideBranch);
+  // Number of old conditional entries still to be used outside.
+  // If the conditional shape is not a tuple, a tuple will be created and
+  // subscript 0 will be used to save the old operand being used.
+  int64 op_index = conditional->shape().IsTuple()
+                       ? conditional->shape().tuple_shapes_size() - 1
+                       : 0;
+  HloGetTupleElementInstruction* tuple_use =
+      dynamic_cast<HloGetTupleElementInstruction*>(to_move_in[0].operands()[0]);
+  int64 use_index = (tuple_use != nullptr) ? tuple_use->tuple_index() : -1;
+  VLOG(2) << "Tuple use index = " << use_index << "\n";
+  // Used to map the tuple_use instruction to its operand.
+  Boundary b_opd_use(Boundary::Position::kInsideBranch);
+  Boundary b_old_root(Boundary::Position::kInsideBranch);
+  // Create a new root instruction in each branch.
   for (int i = 0; i < branch_count; i++) {
     auto computation = conditional->branch_computation(i);
     auto old_root = computation->root_instruction();
-    b.mutable_operands().push_back(old_root);
-    HloInstruction* new_root = nullptr;
+    b_old_root.mutable_operands().push_back(old_root);
+    std::vector<HloInstruction*> operands;
     if (old_root->opcode() == HloOpcode::kTuple) {
-      new_root = computation->AddInstruction(old_root->Clone());
-    } else {
-      std::vector<HloInstruction*> operands;
-      if (!old_root->shape().IsTuple()) {
-        operands.push_back(old_root);
-      } else {
-        const Shape& old_shape = old_root->shape();
-        for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) {
-          auto element =
-              computation->AddInstruction(HloInstruction::CreateGetTupleElement(
-                  old_shape.tuple_shapes(i), old_root, i));
-          operands.push_back(element);
+      // Use operands of old_root directly, so old_root can be removed later.
+      for (int i = 0; i < old_root->operand_count(); ++i) {
+        if (i != use_index) {
+          operands.push_back(old_root->operands()[i]);
+        } else {  // Map conditional use to the tuple operand.
+          b_opd_use.mutable_operands().push_back(old_root->operands()[i]);
         }
       }
-      new_root =
-          computation->AddInstruction(HloInstruction::CreateTuple(operands));
+    } else if (old_root->shape().IsTuple()) {
+      // If old_root is not a kTuple but has tuple shape, elements within the
+      // tuple must be extracted first to be used by the new instructions.
+      const Shape& old_shape = old_root->shape();
+      for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) {
+        auto element =
+            computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+                old_shape.tuple_shapes(i), old_root, i));
+        if (i != use_index) {
+          operands.push_back(element);
+        } else {
+          b_opd_use.mutable_operands().push_back(element);
+        }
+      }
+    } else {
+      // If old_root is not a tuple and does not have tuple shape, use it
+      // to replace the conditional directly in the new computation.
+      b_opd_use.mutable_operands().push_back(conditional);
     }
+
+    HloInstruction* new_root =
+        computation->AddInstruction(HloInstruction::CreateTuple(operands));
     VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
-    computation->set_root_instruction(new_root);
+    computation->set_root_instruction(new_root,
+                                      /*accept_different_shape*/ true);
+    if (old_root->opcode() == HloOpcode::kTuple) {
+      TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
+    }
     VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
   }
-  hoisted_instructions[conditional] = b;
-  for (int64 i = 0; i < to_move_in_size; i++) {
+  // Update get tuple element index of the conditional.
+  if (use_index != -1) {
+    for (auto* user : conditional->users()) {
+      if (user->opcode() == HloOpcode::kGetTupleElement &&
+          user->tuple_index() > use_index) {
+        user->set_tuple_index(user->tuple_index() - 1);
+      }
+    }
+  }
+  hoisted_instructions[conditional] = b_old_root;
+  int64 cp_start = 0;
+  if (use_index >= 0) {
+    hoisted_instructions[tuple_use] = b_opd_use;
+    cp_start = 1;
+  }
+  for (int64 i = cp_start; i < to_move_in_size; i++) {
     Boundary b_to_move = to_move_in[i];
     HloInstruction* op = b_to_move.operands()[0];
     CHECK(op != nullptr);
@@ -591,12 +641,12 @@
     }
     if (to_be_used_outside) {
       // Modify uses of instructions outside of the conditionals
-      HloInstruction* gtr = conditional_parent->AddInstruction(
+      HloInstruction* gtr = conditional->parent()->AddInstruction(
           HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                 op_index++));
       TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr));
-      if (conditional_parent->root_instruction() == op) {
-        conditional_parent->set_root_instruction(gtr);
+      if (conditional->parent()->root_instruction() == op) {
+        conditional->parent()->set_root_instruction(gtr);
       }
     }
   }
@@ -606,8 +656,8 @@
   HloInstruction* new_root =
       conditional->branch_computation(0)->root_instruction();
   *conditional->mutable_shape() = new_root->shape();
-  VLOG(2) << "Before removing instructions:" << conditional_parent->ToString()
-          << "\n";
+  VLOG(2) << "Before removing instructions:"
+          << conditional->parent()->ToString() << "\n";
   // Remove hoisted instructions from the branches.
   for (int64 i = to_move_in_size - 1; i >= 0; i--) {
     Boundary boundary_to_move_in = to_move_in[i];
@@ -616,10 +666,10 @@
     for (auto user : op->users()) {
       VLOG(2) << "Has User: " << user->ToString() << "\n";
     }
-    TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(op));
+    TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
   }
-  VLOG(2) << "Done moving instructions inside branches\n"
-          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
+  VLOG(1) << "Done moving instructions inside branches\n"
+          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
           << "\n";
   return true;
 }
@@ -631,6 +681,7 @@
   HloInstruction* conditional_;
   HloComputation* conditional_parent_;
   bool is_layout_sensitive_;
+  // Instructions that have been visited but are not going to be moved.
   absl::flat_hash_set<HloInstruction*> visited_;
 
  public:
@@ -639,7 +690,7 @@
       : conditional_(conditional),
         conditional_parent_(conditional->parent()),
         is_layout_sensitive_(is_layout_sensitive) {}
-  // Returns true if `instruction` is worth hoisting out.
+  // Returns true if `instruction` is worth hoisting.
   bool WorthHoisting(HloInstruction* instruction) {
     // This is needed for the "moving-in" transformation, to prevent the root
    // of the parent computation (which contains the conditional) from being moved
@@ -663,13 +714,14 @@
           case HloOpcode::kReshape:
             return true;
           default:
-            VLOG(1) << "Instruction is convert and its operand is not know to "
+            VLOG(2) << "Instruction is convert and its operand is not know to "
                        "be worth hoisting\n";
             return false;
         }
       case HloOpcode::kAllReduce:
       case HloOpcode::kAdd:
       case HloOpcode::kPower:
+      case HloOpcode::kCopy:
       case HloOpcode::kConstant:
       case HloOpcode::kSubtract:
       case HloOpcode::kMultiply:
@@ -680,24 +732,28 @@
       case HloOpcode::kGetTupleElement:
         return true;
       default:
-        VLOG(1) << "Instruction is not known to be worth hoisting\n";
+        VLOG(2) << "Instruction is not known to be worth hoisting\n";
         return false;
     }
   }
   int64 ReusesBeforeBoundary(HloInstruction* user) {
     int64 reuses = 0;
     for (auto op : user->operands()) {
+      // The operand must be an instruction that is not going to be moved (if
+      // user is inside the conditional); otherwise it must be the conditional
+      // itself and its user must be outside of the conditional.
+      if (!ContainsKey(visited_, op) && op != conditional_) {
+        continue;
+      }
      // Only consider single-user cases as reusable.
-      if (ContainsKey(visited_, op) && op->user_count() == 1) {
+      if (user->opcode() == HloOpcode::kGetTupleElement &&
+          user->user_count() == 1) {
+        reuses += ReusesCarriedBy(op, user->users()[0]);
+      } else if (op->user_count() == 1) {
         reuses += ReusesCarriedBy(op, user);
-      } else if (op->opcode() == HloOpcode::kConditional &&
-                 user->opcode() == HloOpcode::kGetTupleElement) {
-        if (user->user_count() == 1) {
-          reuses += ReusesCarriedBy(op, user->users()[0]);
-        }
       }
     }
-    VLOG(1) << "Reuses before instruction " << user->ToString() << ":" << reuses
+    VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses
             << "\n";
     return reuses;
   }
@@ -735,7 +791,7 @@
       } else if (ContainsKey(visited_, op)) {
         reuses += ReusesCarriedBy(user, op);
       }
-      VLOG(1) << "reuses after instruction " << user->ToString() << ":"
+      VLOG(2) << "reuses after instruction " << user->ToString() << ":"
               << reuses << "\n";
       return reuses;
     }
@@ -744,7 +800,8 @@
 
   int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
     int64 reuses_before = 0, reuses_after = 0;
-    if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch()) {
+    if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch() &&
+        boundaries[0].operands()[0]->opcode() == HloOpcode::kGetTupleElement) {
       // The only boundary of moving-in is the get_tuple_element op.
       return -1;
     }
@@ -754,16 +811,16 @@
         continue;
       }
       reuses_before += ReusesBeforeBoundary(op);
-      VLOG(1) << "Reuses before boundary so far: " << reuses_before << "\n";
+      VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n";
       reuses_after += ReusesAfterBoundary(op);
-      VLOG(1) << "Reuese after boundary so far : " << reuses_after << "\n";
+      VLOG(2) << "Reuese after boundary so far : " << reuses_after << "\n";
     }
     if (reuses_after == 0 && reuses_before == 0) {
       return -1;
     } else if (boundaries[0].IsInsideBranch()) {
       return reuses_after - reuses_before;
     } else {
-      return reuses_before - reuses_after;
+      return reuses_before - reuses_after - 1;
     }
   }
 
@@ -800,12 +857,12 @@
     visitor.AddToWorkList(boundary);
     while (visitor.HasNextBoundary()) {
       Boundary b = visitor.PopNextBoundary();
-      VLOG(1) << "visiting boundary " << b.ToString() << "\n";
+      VLOG(2) << "visiting boundary " << b.ToString() << "\n";
       if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
                                       b.operands(), is_layout_sensitive_)) &&
           WorthHoisting(b.operands()[0])) {
         connected_boundaries_.push_back(b);
-        VLOG(1) << "boundary can be moved\n";
+        VLOG(2) << "boundary can be moved\n";
         int64 operand_count = (b.IsInsideBranch())
                                   ? b.operands()[0]->operand_count()
                                   : b.operands()[0]->users().size();
@@ -829,7 +886,7 @@
           }
         }
       } else {
-        VLOG(1) << "boundary cannot be moved\n";
+        VLOG(2) << "boundary cannot be moved\n";
         visited_.insert(b.operands()[0]);
         new_boundaries_.push_back(b);
       }
@@ -876,7 +933,7 @@
   auto move_in_or_out = connect.BoundariesToMoveInOrOut(cur_boundary);
   if (!move_in_or_out.empty()) {
     auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
-    VLOG(1) << "benefit of moving in or out "
+    VLOG(2) << "benefit of moving in or out "
             << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
     if (benefit >= 0) {
       new_boundaries.clear();
@@ -899,9 +956,20 @@
   // Gather all the conditional ops in the module ahead of time, to avoid
  // potential complications of modifying the code that is being traversed.
   std::vector<HloInstruction*> conditional_ops;
+  // Track how many times each branch computation is shared.
+  absl::flat_hash_map<HloComputation*, int> conditional_computations;
   for (auto* comp : module->MakeComputationPostOrder()) {
     for (auto* instr : comp->MakeInstructionPostOrder()) {
       if (instr->opcode() == HloOpcode::kConditional) {
+        int branch_count = instr->branch_count();
+        for (int i = 0; i < branch_count; ++i) {
+          HloComputation* branch_i = instr->branch_computation(i);
+          if (ContainsKey(conditional_computations, branch_i)) {
+            conditional_computations[branch_i]++;
+          } else {
+            conditional_computations[branch_i] = 0;
+          }
+        }
         conditional_ops.push_back(instr);
       }
     }
@@ -909,6 +977,17 @@
 
   bool changed = false;
   for (HloInstruction* conditional : conditional_ops) {
+    int branch_count = conditional->branch_count();
+    // Check for shared conditional computations.
+    bool conditional_is_shared = false;
+    for (int i = 0; i < branch_count; ++i) {
+      HloComputation* branch_i = conditional->branch_computation(i);
+      if (conditional_computations[branch_i] > 0) {
+        conditional_is_shared = true;
+        break;
+      }
+    }
+
     // Boundaries to move out or to move into the branches.
     std::vector<Boundary> to_move_out, to_move_in, new_boundaries;
     // The conditional is moved into a worklist as the seed (starting point).
@@ -926,6 +1005,33 @@
       Boundary boundary = visitor.PopNextBoundary();
       VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
       d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary);
+      if (d != Decision::kNoChange && conditional_is_shared) {
+        for (int i = 0; i < branch_count; ++i) {
+          HloComputation* branch_i = conditional->branch_computation(i);
+          if (conditional_computations[branch_i] > 0) {
+            // Cloning is absolutely needed if the computation is shared by
+            // different branches, but the cloning can be potentially avoided
+            // if the sharing is only among branches of the same conditional.
+            // If cloning these branches causes a problem due to space issues,
+            // a fix can pass a vector of unique branches to the actual
+            // transformations, as an alternative representation of the
+            // conditional branches to be modified. Right now we assume the
+            // overhead of cloning is minimal since later stages of the compiler
+            // inline all the computations anyway.
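+            // For example, if two conditionals share the same branch
+            // computation C, transforming C on behalf of the first
+            // conditional would silently change the second conditional as
+            // well; pointing this conditional at a fresh clone avoids that.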
+            HloComputation* clone_i =
+                conditional->parent()->parent()->AddEmbeddedComputation(
+                    branch_i->Clone());
+            conditional->set_branch_computation(i, clone_i);
+            conditional_computations[branch_i]--;
+          }
+        }
+        to_move.clear();
+        next_boundary.clear();
+        VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
+                << "\n";
+        // Need to reanalyze the cloned code to generate correct result.
+        d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary);
+      }
       switch (d) {
         case Decision::kMoveOutOfBranch:
           VLOG(2) << "Decision is move out of branch\n";
@@ -961,22 +1067,14 @@
           MoveInstructionIn(conditional, to_move_in, new_boundaries));
       VLOG(2) << "moving in result:" << result << "\n";
       changed |= result;
-    }
-  }
-  // handling convert rematerialization/hoisting
-  if (!changed && pursue_full_conditional_code_motion_) {
-    std::vector<HloInstruction*> conditional_ops;
-    for (auto* comp : module->MakeComputationPostOrder()) {
-      for (auto* instr : comp->MakeInstructionPostOrder()) {
-        if (instr->opcode() == HloOpcode::kConditional) {
-          conditional_ops.push_back(instr);
-        }
-      }
-    }
-    for (HloInstruction* conditional_op : conditional_ops) {
+    } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
+      // Invoke special handling for convert rematerialization/hoisting.
+      // We need to make sure no sharing is present in the branches, because
+      // no cloning has been done by the earlier analysis.
+      // TODO(b/165848866): Extend the solution to handle cloning for the
+      // special move.
       TF_ASSIGN_OR_RETURN(
           bool convert_result,
-          ConvertSpecialMove(conditional_op, is_layout_sensitive_));
+          ConvertSpecialMove(conditional, is_layout_sensitive_));
       changed |= convert_result;
     }
   }
@@ -986,6 +1084,7 @@
     subpipeline.AddPass<HloDCE>();
     subpipeline.AddPass<TupleSimplifier>();
     subpipeline.AddPass<HloDCE>();
+    subpipeline.AddPass<HloVerifier>(false, true);
     TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
     changed |= cleanup_changed;
   }
diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
index b0a6ba9..e5e3873 100644
--- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc
@@ -580,6 +580,214 @@
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
 }
+
+TEST_F(ConditionalCodeMotionTest, MovePowInWithSharedBranch) {
+  absl::string_view hlo_string =
+      R"(
+HloModule RemoveIdenticalInstruction
+
+branch {
+  arg_tuple.1 = (f32[10]) parameter(0)
+  get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0
+  add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1)
+  ROOT tuple.3 = (f32[10]) tuple(add.1)
+}
+
+ENTRY main {
+  pred.1 = pred[] parameter(0)
+  tuple.1 = (f32[10]) parameter(1)
+  tuple.2 = (f32[10]) parameter(2)
+  conditional = (f32[10])
+    conditional(pred.1, tuple.1, tuple.2), true_computation=branch,
+    false_computation=branch
+  get-first-index = f32[10] get-tuple-element(conditional), index=0
+  ROOT pow.1 = f32[10] power(get-first-index, get-first-index)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  ConditionalCodeMotion pass(true, true);
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+  const HloInstruction* conditional =
+      FindInstruction(module.get(), "conditional");
+  const HloComputation* on_true = conditional->branch_computation(0);
+  ASSERT_EQ(on_true->instruction_count(), 5);
+  const HloComputation* on_false = conditional->branch_computation(1);
+  ASSERT_EQ(on_false->instruction_count(), 5);
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
+}
+
+TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleRoot) {
+  absl::string_view hlo_string =
+      R"(
+HloModule RemoveIdenticalInstruction
+
+branch {
+  arg_tuple.1 = (f32[10]) parameter(0)
+  get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0
+  ROOT add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1)
+}
+
+ENTRY main {
+  pred.1 = pred[] parameter(0)
+  tuple.1 = (f32[10]) parameter(1)
+  tuple.2 = (f32[10]) parameter(2)
+  conditional = f32[10]
+    conditional(pred.1, tuple.1, tuple.2), true_computation=branch,
+    false_computation=branch
+  ROOT pow.1 = f32[10] power(conditional, conditional)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  ConditionalCodeMotion pass(true, true);
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+  const HloInstruction* conditional =
+      FindInstruction(module.get(), "conditional");
+  const HloComputation* on_true = conditional->branch_computation(0);
+  ASSERT_EQ(on_true->instruction_count(), 5);
+  const HloComputation* on_false = conditional->branch_computation(1);
+  ASSERT_EQ(on_false->instruction_count(), 5);
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
+}
+
+TEST_F(ConditionalCodeMotionTest, MovePowInWithEmptyBranch) {
+  absl::string_view hlo_string =
+      R"(
+HloModule RemoveIdenticalInstruction
+
+branch1 {
+  arg_tuple.1 = (f32[10]) parameter(0)
+  get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0
+  add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1)
+  ROOT tuple.3 = (f32[10]) tuple(add.1)
+}
+
+branch2 {
+  ROOT arg_tuple.1 = (f32[10]) parameter(0)
+}
+
+ENTRY main {
+  pred.1 = pred[] parameter(0)
+  tuple.1 = (f32[10]) parameter(1)
+  tuple.2 = (f32[10]) parameter(2)
+  conditional = (f32[10])
+    conditional(pred.1, tuple.1, tuple.2), true_computation=branch1,
+    false_computation=branch2
+  get-first-index = f32[10] get-tuple-element(conditional), index=0
+  ROOT pow.1 = f32[10] power(get-first-index, get-first-index)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  ConditionalCodeMotion pass(true, true);
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+  const HloInstruction* conditional =
+      FindInstruction(module.get(), "conditional");
+  const HloComputation* on_true = conditional->branch_computation(0);
+  ASSERT_EQ(on_true->instruction_count(), 5);
+  const HloComputation* on_false = conditional->branch_computation(1);
+  ASSERT_EQ(on_false->instruction_count(), 4);
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
+}
+
+TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleParameter) {
+  absl::string_view hlo_string =
+      R"(
+HloModule RemoveIdenticalInstruction
+
+branch {
+  arg.1 = f32[10] parameter(0)
+  ROOT add.1 = f32[10] add(arg.1, arg.1)
+}
+
+ENTRY main {
+  pred.1 = pred[] parameter(0)
+  tuple.1 = f32[10] parameter(1)
+  tuple.2 = f32[10] parameter(2)
+  conditional = f32[10]
+    conditional(pred.1, tuple.1, tuple.2), true_computation=branch,
+    false_computation=branch
+  ROOT pow.1 = f32[10] power(conditional, conditional)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  ConditionalCodeMotion pass(true, true);
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+  const HloInstruction* conditional =
+      FindInstruction(module.get(), "conditional");
+  const HloComputation* on_true = conditional->branch_computation(0);
+  ASSERT_EQ(on_true->instruction_count(), 4);
+  const HloComputation* on_false = conditional->branch_computation(1);
+  ASSERT_EQ(on_false->instruction_count(), 4);
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional())));
+}
+
+TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) {
+  absl::string_view hlo_string =
+      R"(
+HloModule RemoveIdenticalInstruction
+
+branch1 {
+  arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0)
+  constant.1 = s32[] constant(4)
+  get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0
+  add.1 = s32[] add(get-tuple-element.1, constant.1)
+  get-tuple-element.2 = f32[10,3]{0,1} get-tuple-element(arg_tuple.1), index=1
+  slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2),
+   slice={[0:4:1], [0:3:1]}
+  constant.2 = f32[] constant(0.0)
+  ROOT tuple.1 = (f32[4,3]{0,1}, s32[],f32[]) tuple(slice.1, add.1, constant.2)
+}
+
+branch2 {
+  arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0)
+  get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0
+  copy.1 = s32[] copy(get-tuple-element.3)
+  get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1
+  copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4)
+  constant.2 = f32[] constant(0.0)
+  ROOT tuple.2 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.2, copy.1, constant.2)
+}
+
+ENTRY main {
+  pred.1 = pred[] parameter(0)
+  tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1)
+  tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2)
+  conditional = (f32[4,3]{0,1}, s32[], f32[])
+    conditional(pred.1, tuple.3, tuple.4), true_computation=branch1,
+    false_computation=branch2
+  get-zero-index = f32[4,3]{0,1} get-tuple-element(conditional), index=0
+  get-first-index = s32[] get-tuple-element(conditional), index=1
+  get-second-index = f32[] get-tuple-element(conditional), index=2
+  copy.3 = f32[4,3]{1,0} copy(get-zero-index)
+  ROOT tuple.5 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.3, get-first-index,
+                 get-second-index)
+}
+)";
+  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  ConditionalCodeMotion pass(true, true);
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+  VLOG(1) << module->ToString();
+
+  const HloInstruction* conditional =
+      FindInstruction(module.get(), "conditional");
+  const HloComputation* on_true = conditional->branch_computation(0);
+  ASSERT_EQ(on_true->instruction_count(), 9);
+  const HloComputation* on_false = conditional->branch_computation(1);
+  ASSERT_EQ(on_false->instruction_count(), 8);
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              AllOf(op::Tuple(op::GetTupleElement(op::Conditional(), 2),
+                              op::GetTupleElement(op::Conditional(), 0),
+                              op::GetTupleElement(op::Conditional(), 1))));
+}
+
 }  // namespace conditional_opt
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index b622b71..4e25d66 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -130,11 +130,14 @@
         ":target_machine_features",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/types:span",
+        "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
-        "@llvm-project//mlir:ExecutionEngineUtils",
         "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:VectorOps",
         "//tensorflow/compiler/xla/service:copy_insertion",
-        "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:dump",
         "//tensorflow/compiler/xla/service:topk_rewriter",
         "//tensorflow/compiler/xla/service:map_inliner",
@@ -198,7 +201,6 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "@llvm-project//llvm:Core",
-        "@llvm-project//llvm:MC",
         "@llvm-project//llvm:Object",
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:Target",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 7b72d7a..e6c72e6 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -42,7 +42,12 @@
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/Dialect/Vector/VectorOps.h"  // from @llvm-project
 #include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
 #include "tensorflow/compiler/xla/literal.h"
@@ -121,6 +126,21 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/dynamic_annotations.h"
 
+namespace {
+
+// We need to explicitly load all the dialects that will be involved in
+// emitting the IR. This is only needed because of how MLIR is bolted into
+// XLA and does not make use of the MLIR infrastructure (like using a proper
+// pass pipeline). Hopefully this will all go away at some point in favor of
+// a better integration.
+void LoadMLIRDialects(mlir::MLIRContext& context) {
+  context.loadDialect<mlir::linalg::LinalgDialect, mlir::scf::SCFDialect,
+                      mlir::vector::VectorDialect, mlir::StandardOpsDialect,
+                      mlir::AffineDialect>();
+}
+
+}  // namespace
+
 namespace xla {
 namespace cpu {
 using BufferInfo = cpu_function_runtime::BufferInfo;
@@ -164,8 +184,6 @@
   // Initialize LLVM's MC layer for the native target.
   llvm::InitializeNativeTarget();
   llvm::InitializeNativeTargetAsmPrinter();
-
-  mlir::registerAllDialects();
 }
 
 namespace {
@@ -622,6 +640,7 @@
 
   // Compile must be thread-safe so create a new LLVM context for the module.
   mlir::MLIRContext mlir_context;
+  LoadMLIRDialects(mlir_context);
   llvm::LLVMContext llvm_context;
   auto llvm_module =
       absl::make_unique<llvm::Module>("__compute_module", llvm_context);
@@ -833,6 +852,7 @@
 
   // Compile must be thread-safe so create a new LLVM context for the module.
   mlir::MLIRContext mlir_context;
+  LoadMLIRDialects(mlir_context);
   llvm::LLVMContext llvm_context;
   llvm::Module llvm_module("__compute_module", llvm_context);
   llvm_module.setDataLayout(target_machine->createDataLayout());
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index f1a0b0a..cbed232 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -276,7 +276,7 @@
     llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
     if (llvm::isa<llvm::VectorType>(type)) {
       return llvm::ConstantVector::getSplat(
-          llvm::ElementCount(vector_size(), /*Scalable=*/false), scalar_value);
+          llvm::ElementCount::getFixed(vector_size()), scalar_value);
     }
     return scalar_value;
   }
diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc
index 75d3929..4730d9c 100644
--- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc
+++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc
@@ -96,16 +96,7 @@
     const HloInstruction* instruction) {
   int64 total = 0;
   for (const auto* user : indexing_users_[instruction]) {
-    int64 weight = 1;
-    // Concatenate is special: the index differs for each operand, so
-    // in the worst case we have to deal with as many index values as
-    // the number of operands of Concatenate. By considering the worst
-    // case, we are more conservative than necessary regarding
-    // counting the index usage.
-    if (user->opcode() == HloOpcode::kConcatenate) {
-      weight = user->operand_count();
-    }
-    total += index_usage_count_.at(user) * weight;
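+    // Every user now contributes exactly its own index usage count;
+    // concatenate no longer receives the per-operand extra weight that the
+    // removed special case applied.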
+    total += index_usage_count_.at(user);
   }
   CHECK(index_usage_count_.emplace(instruction, total).second);
   total_emitted_instructions_ += total;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d1d0827..ce761d8 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -254,6 +254,7 @@
         ":target_util",
         ":thunk",
         ":thunk_emitter",
+        "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/mlir/xla:hlo_utils",
         "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 60e4cb8..a499dc7 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -230,18 +230,15 @@
   // This is done to avoid the duplication of expensive instructions, which
   // would occur if 'fusion' were merged into multiple users.
   //
-  // If 'fusion' has just one user, then an earlier fusion pass chose not to
-  // fuse this producer/consumer pair (likely because of expensive instruction
-  // re-use by the consumer), and so we honor that choice here as well.
-  //
-  // Moreover, if we are going to save a "lot" in memory bandwidth then we
+  // However, if we are going to save a "lot" in memory bandwidth, then we
   // ignore how expensive the fusion instructions are.  The heuristic used to
   // determine "a lot" is the following: merging must reduce memory traffic by a
   // factor of 0.3, and the amount of memory accessed must not be entirely
   // trivial (above 1K).  This likely has room for improvement in the future.
 
   bool allow_expensive_ops =
-      merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024;
+      fusion->user_count() == 1 ||
+      (merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024);
 
   if (!allow_expensive_ops &&
       absl::c_any_of(fusion->fused_instructions(),
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index 4289115..cc4894f 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -398,6 +398,29 @@
   EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
 }
 
+TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule m
+
+    %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+      %p = f32[1024,1024,1024] parameter(0)
+      ROOT %t = f32[1024,1024,1024] tanh(%p)
+    }
+
+    %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
+      %p = f32[1024,1024,1024] parameter(0)
+      ROOT %t = f32[1024,1024,1024] add(%p, %p)
+    }
+
+    ENTRY entry {
+      p0 = f32[1024,1024,1024] parameter(0)
+      f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
+      ROOT f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c
+    })")
+                    .ValueOrDie();
+  EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index bb4184f..19aa8d0 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -347,8 +347,13 @@
 // This limit is also often good for performance.  In a fusion with many
 // operands, each GPU thread likely has to do a lot of work, and so possibly
 // uses a lot of registers, thus limiting occupancy.
+//
+// If the fusion is a producer/consumer fusion and instr1 is the
+// consumer and instr2 is the producer, set is_consumer_producer_fusion
+// to true to enable more fusion.
 bool FusionWouldBeTooLarge(const HloInstruction& instr1,
-                           const HloInstruction& instr2) {
+                           const HloInstruction& instr2,
+                           bool is_consumer_producer_fusion) {
   if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) >
       kSharedMemoryBudgetInBytes) {
     VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString()
@@ -404,6 +409,17 @@
   // producer -> consumer relationship.
   operands.erase(&instr1);
   operands.erase(&instr2);
+
+  // If the fusion generates no more inputs and outputs than before, it
+  // won't be bigger after fusion, so accept it. As this is a
+  // consumer/producer fusion, the number of consumers of the output does
+  // not change, so there is no need to check it.
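+  //
+  // For example (hypothetical operand sets): if the consumer reads
+  // {producer, p0} and the producer reads {p0, p1}, the merged fusion
+  // reads {p0, p1}, which is no more operands than the consumer already
+  // had, so the fusion is accepted without consulting the limit below.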
+  if (is_consumer_producer_fusion &&
+      operands.size() <= instr1.operands().size()) {
+    return false;
+  }
+
+  // Does the new fusion have more operands and outputs than the max?
   return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion;
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
index e2a42ec..7dfc59a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
@@ -64,8 +64,12 @@
 // Determines whether the combination of `instr1` and `instr2` into a (possibly
 // multi-output) fusion would be "too large" -- i.e., have more operands and
 // outputs than is allowed or occupy too much shared memory.
+// If the fusion is a producer/consumer fusion and instr1 is the
+// consumer and instr2 is the producer, set is_consumer_producer_fusion
+// to true to enable more fusion.
 bool FusionWouldBeTooLarge(const HloInstruction& instr1,
-                           const HloInstruction& instr2);
+                           const HloInstruction& instr2,
+                           bool is_consumer_producer_fusion = false);
 
 // Check if fusing producer and consumer will generate a nested loop, e.g. both
 // producer and consumer are `reduce-window` HLO instructions.
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index b994ead..d85150b 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -60,6 +60,7 @@
 
   // Output fusions are not currently supported on GPUs.
   if (producer->opcode() == HloOpcode::kFusion) {
+    VLOG(4) << "Producer " << producer->name() << " is a fusion op";
     return false;
   }
   // Cost condition: not fuse (simple, expensive producers) and (consumers who
@@ -67,11 +68,15 @@
   if (producer->opcode() != HloOpcode::kFusion &&
       consumer->ReusesOperandElements(operand_index) &&
       is_expensive(*producer)) {
+    VLOG(4) << "Do not fuse simple, expensive producer " << producer->name()
+            << " and consumer which reuses operand elements.";
     return false;
   }
 
   if (!IsProducerConsumerFusible(*producer, *consumer) ||
       !InstructionFusion::ShouldFuse(consumer, operand_index)) {
+    VLOG(4) << "Producer " << producer->name()
+            << " is not fusible or should not be fused.";
     return false;
   }
   return true;
@@ -87,7 +92,8 @@
   auto producer = consumer->operand(operand_index);
 
   // The following checks are potentially expensive.
-  if (FusionWouldBeTooLarge(*consumer, *producer)) {
+  if (FusionWouldBeTooLarge(*consumer, *producer,
+                            /*is_consumer_producer_fusion=*/true)) {
     VLOG(5) << "Fusion of (" << producer->ToString() << ") into ("
             << consumer->ToString() << ") would be too large";
     return false;
@@ -107,8 +113,13 @@
     fusion_node_evaluations_.emplace(consumer,
                                      FusionNodeIndexingEvaluation(consumer));
   }
-  return !fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh(
-      producer);
+  if (fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh(
+          producer)) {
+    VLOG(5) << "Fusion of " << producer->name() << " into " << consumer->name()
+            << " would result in overly large code duplication.";
+    return false;
+  }
+  return true;
 }
 
 bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
index 7d5a8d0..2a493fe 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
@@ -17,7 +17,10 @@
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_
 
 #include "llvm/IR/Module.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
@@ -44,7 +47,11 @@
         cuda_compute_capability_(cuda_compute_capability),
         profile_index_map_(profile_index_map),
         mlir_context_(mlir_context),
-        llvm_module_(llvm_module) {}
+        llvm_module_(llvm_module) {
+    mlir_context_
+        ->loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
+                      mlir::StandardOpsDialect>();
+  }
   // Disallow copy and assign.
   IrEmitterContext(const IrEmitterContext&) = delete;
   IrEmitterContext& operator=(const IrEmitterContext&) = delete;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index b9146dd..c295568 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -404,11 +404,10 @@
   // the process. `scatter` may be fused, scatter indices are taken from
   // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is
   // expected to have the operand values in it already. If unique_indices
-  // is false, we will use an atomic update. Using false for unique_indices
-  // is safe only when it is guaranteed that there are no duplicate
-  // indices.
-  // When using unique_indices=true, it is the caller's responsibility to
-  // ensure there is no overlap.
+  // is false, we will use an atomic update. Using true for unique_indices
+  // behaves properly only when it is guaranteed that the indices to be
+  // updated do not overlap. The caller is responsible for ensuring this is
+  // the case.
   Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
                      const llvm_ir::ElementGenerator& scatter_indices_gen,
                      const llvm_ir::ElementGenerator& updates_gen);
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 1228a1b..04af67a 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -62,8 +62,10 @@
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/util/env_var.h"
 
 namespace xla {
 namespace gpu {
@@ -86,14 +88,21 @@
   int sm_version = 30;
   // If the current compute capability isn't known, fallback to the
   // most recent version before it.
-  for (int v : {75, 72, 70, 62, 61, 60, 53, 52, 50, 37, 35, 32, 30}) {
+  int supported_versions[] = {75, 72, 70, 62, 61, 60, 53,
+                              52, 50, 37, 35, 32, 30};
+  for (int v : supported_versions) {
     if (v <= compute_capability_version) {
       sm_version = v;
       break;
     }
   }
 
-  if (sm_version != compute_capability_version) {
+  // If the current CC isn't supported by LLVM and it is newer than
+  // the newest CC LLVM supports, do not warn about it. The end
+  // user can't do anything about this. PTX compiled for SM75 will
+  // run on SM80 too.
+  if (sm_version != compute_capability_version &&
+      compute_capability_version < supported_versions[0]) {
     LOG(WARNING) << "Unknown compute capability (" << compute_capability.first
                  << ", " << compute_capability.second << ") ."
                  << "Defaulting to telling LLVM that we're compiling for sm_"
@@ -570,6 +579,60 @@
   return result;
 }
 
+struct HsacoCacheEntry {
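+  // One cached compilation: the LLVM IR of a module, its hash, the AMDGPU
+  // gfx version it targets, and the resulting HSACO binary.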
+  uint64 hash;
+  std::string ir;
+  int gfx;
+  std::vector<uint8> hsaco;
+};
+
+struct HsacoCache {
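+  // Process-wide cache of compiled HSACO binaries, keyed by the hash of
+  // the IR plus the gfx version; Find/Add serialize access via m_mutex.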
+ protected:
+  std::vector<HsacoCacheEntry> cache;
+  std::mutex m_mutex;
+  int request_count = 0;
+  int hit_count = 0;
+
+ public:
+  static bool Find(const std::string& ir, uint64_t& hash, int gfx,
+                   std::vector<uint8>& hsaco);
+  static void Add(const std::string& ir, uint64_t hash, int gfx,
+                  const std::vector<uint8>& hsaco);
+};
+
+static HsacoCache g_hsacoCache;
+
+bool HsacoCache::Find(const std::string& ir, uint64_t& hash, int gfx,
+                      std::vector<uint8>& hsaco) {
+  std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
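+  // Hash the IR once and return it through the out-parameter so that a
+  // later Add() on a cache miss can reuse it.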
+  hash = std::hash<std::string>{}(ir);
+  bool hit = false;
+  for (auto& x : g_hsacoCache.cache) {
+    if (x.hash != hash) continue;
+    if (x.gfx != gfx) continue;
+    if (x.ir != ir) continue;
+    hsaco = x.hsaco;
+    hit = true;
+    break;
+  }
+  g_hsacoCache.request_count++;
+  if (hit) g_hsacoCache.hit_count++;
+  if (g_hsacoCache.request_count % 50 == 0) {
+    VLOG(1) << "HSACO cache: " << g_hsacoCache.request_count << " requests, "
+            << g_hsacoCache.hit_count << " hits";
+  }
+  return hit;
+}
+
+void HsacoCache::Add(const std::string& ir, uint64_t hash, int gfx,
+                     const std::vector<uint8>& hsaco) {
+  std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
+  g_hsacoCache.cache.emplace_back();
+  HsacoCacheEntry& entry = g_hsacoCache.cache.back();
+  entry.ir = ir;
+  entry.hash = hash;
+  entry.gfx = gfx;
+  entry.hsaco = hsaco;
+}
+
 // Emits the given module to HSA Code Object. target_machine is an initialized
 // TargetMachine for the AMDGPU target.
 StatusOr<std::vector<uint8>> EmitModuleToHsaco(
@@ -584,18 +647,29 @@
   std::string tempdir_name = tempdir_vector.front();
   VLOG(1) << "Compile-time artifacts located at: " << tempdir_name;
 
+  bool keep_tempfiles = false;
+  TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_ROCM_KEEP_XLA_TEMPFILES",
+                                             /*default_val=*/false,
+                                             &keep_tempfiles));
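+  // When TF_ROCM_KEEP_XLA_TEMPFILES=1 is set, the per-compilation temp
+  // files (IR, ISA binary, HSACO) are kept on disk for debugging instead
+  // of being deleted after the HSACO is read back.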
   // Prepare filenames for all stages of compilation:
   // IR, binary ISA, and HSACO.
-  std::string ir_filename = absl::StrCat(module->getModuleIdentifier(), ".ll");
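+  // A random suffix keeps concurrent compilations that share this temp
+  // directory from clobbering each other's artifacts.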
+  std::string random_number = std::to_string(tensorflow::random::New64());
+  std::string ir_filename =
+      absl::StrCat(module->getModuleIdentifier(), random_number, ".ll");
   std::string ir_path = tensorflow::io::JoinPath(tempdir_name, ir_filename);
 
+  std::string ir_opt_filename =
+      absl::StrCat(module->getModuleIdentifier(), random_number, "_opt.ll");
+  std::string ir_opt_path =
+      tensorflow::io::JoinPath(tempdir_name, ir_opt_filename);
+
   std::string isabin_filename =
-      absl::StrCat(module->getModuleIdentifier(), ".o");
+      absl::StrCat(module->getModuleIdentifier(), random_number, ".o");
   std::string isabin_path =
       tensorflow::io::JoinPath(tempdir_name, isabin_filename);
 
   std::string hsaco_filename =
-      absl::StrCat(module->getModuleIdentifier(), ".hsaco");
+      absl::StrCat(module->getModuleIdentifier(), random_number, ".hsaco");
   std::string hsaco_path =
       tensorflow::io::JoinPath(tempdir_name, hsaco_filename);
 
@@ -613,7 +687,7 @@
   std::string module_id = module->getModuleIdentifier();
   IrDumpingPassManager codegen_passes(
       ReplaceFilenameExtension(tensorflow::io::Basename(module_id),
-                               "-amdgpu.dummy"),
+                               random_number + "-amdgpu.dummy"),
       "", false);
   codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
       llvm::Triple(module->getTargetTriple())));
@@ -627,6 +701,12 @@
   codegen_passes.run(*module);
   isabin_fs->flush();
 
+  if (keep_tempfiles) {
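+    // Also dump the post-codegen (optimized) IR alongside the other
+    // artifacts.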
+    std::unique_ptr<llvm::raw_fd_ostream> ir_fs(
+        new llvm::raw_fd_ostream(ir_opt_path, ec, llvm::sys::fs::F_None));
+    module->print(*ir_fs, nullptr);
+    ir_fs->flush();
+  }
   // Locate lld.
   // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after
   // ROCm-Device-Libs PR.
@@ -652,9 +732,9 @@
   int lld_result =
       llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args),
                                 llvm::None, {}, 0, 0, &error_message);
-
   if (lld_result) {
-    return xla::InternalError("ld.lld execute fail: %s", error_message);
+    return xla::InternalError("ld.lld execute fail: %s, error code %d",
+                              error_message, lld_result);
   }
 
   // Read HSACO.
@@ -664,6 +744,12 @@
   std::vector<uint8> hsaco(hsaco_file_size);
   hsaco_file.seekg(0, std::ios::beg);
   hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size);
+  hsaco_file.close();
+  if (!keep_tempfiles) {
+    remove(ir_path.c_str());
+    remove(isabin_path.c_str());
+    remove(hsaco_path.c_str());
+  }
   return hsaco;
 }
 
@@ -728,6 +814,20 @@
 
   std::vector<uint8> hsaco;
   std::unique_ptr<llvm::TargetMachine> target_machine;
+  std::string str;
+  llvm::raw_string_ostream stream(str);
+  stream << *module;
+  // Delete the first two lines, since they usually vary even when the rest of
+  // the code is the same (but verify that they are what we expect).
+  if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") {
+    auto pos = str.find("\n");
+    if (pos != std::string::npos) str = str.substr(pos + 1);
+  }
+  if (str.size() >= 18 && str.substr(0, 18) == "source_filename = ") {
+    auto pos = str.find("\n");
+    if (pos != std::string::npos) str = str.substr(pos + 1);
+  }
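+  // Fold the compilation cache key into the string so that modules with
+  // identical IR but different compilation options hash differently.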
+  str += hlo_module_config.compilation_cache_key();
   {
     tensorflow::profiler::TraceMe activity(
         [&] { return absl::StrCat("Compiling IR", module->getName().str()); },
@@ -739,6 +839,21 @@
       return xla::InternalError(
           "Incompatible AMD GCN ISA version was specified.");
     }
+    uint64_t hash;
+    if (HsacoCache::Find(str, hash, *amdgpu_version, hsaco)) {
+      VLOG(1) << "HSACO cache hit";
+      return hsaco;
+    }
+    VLOG(1) << "HSACO cache miss";
+    bool dump_lls = false;
+    if (dump_lls) {
+      static int hsaco_count = 0;
+      std::string name = "/tmp/" + std::to_string(hsaco_count) + ".ll";
+      hsaco_count++;
+      std::ofstream ofs(name);
+      ofs << str;
+      ofs.close();
+    }
 
     llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
     // Construct LLVM TargetMachine for AMDGPU.
@@ -754,6 +869,7 @@
 
     // Lower optimized LLVM module to HSA code object.
     TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get()));
+    HsacoCache::Add(str, hash, *amdgpu_version, hsaco);
   }
   return hsaco;
 }
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 809b277..f7d9e93 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -375,6 +375,8 @@
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/service:hlo_module_config",
         "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/gpu:gpu_fusible",
+        "//tensorflow/compiler/xla/service/gpu:instruction_fusion",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc
index 674b436..811705d 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc
@@ -15,6 +15,8 @@
 
 #include <utility>
 
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
+#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
@@ -54,6 +56,37 @@
       )");
 }
 
+// Check that we limit the number of operands to fusions we create.
+TEST_F(GpuFusionTest, FusedBiggerThanThresholdButDoNotChangeTheFusion) {
+  constexpr int64 kNumParams = kMaxOperandsAndOutputsPerFusion + 1;
+
+  // Slice the input many times and concatenate the slices, using so many
+  // fusion operands that they do not fit into one fusion.
+  auto module = CreateNewVerifiedModule();
+  HloComputation::Builder b(TestName());
+  Shape input_shape = ShapeUtil::MakeShape(F32, {10, 100});
+  Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 2});
+  Shape concat_shape = ShapeUtil::MakeShape(F32, {10, 2 * kNumParams});
+  HloInstruction* input =
+      b.AddInstruction(HloInstruction::CreateParameter(0, input_shape, "p"));
+
+  std::vector<HloInstruction*> slice_params;
+  for (int64 i = 0; i < kNumParams; ++i) {
+    slice_params.push_back(b.AddInstruction(HloInstruction::CreateSlice(
+        slice_shape, input, {0, 0}, {10, 2}, {1, 1})));
+  }
+  b.AddInstruction(
+      HloInstruction::CreateConcatenate(concat_shape, slice_params, 1));
+  module->AddEntryComputation(b.Build());
+  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false)
+                  .Run(module.get())
+                  .ValueOrDie());
+  EXPECT_EQ(module->entry_computation()->root_instruction()->opcode(),
+            HloOpcode::kFusion);
+  for (HloInstruction* instr : module->entry_computation()->instructions()) {
+    EXPECT_NE(instr->opcode(), HloOpcode::kSlice);
+  }
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 17a7b18..c3a7b3a 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -35,7 +35,7 @@
 option cc_enable_arenas = true;
 
 // Serialization of HloInstruction.
-// Next ID: 73
+// Next ID: 74
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -251,6 +251,9 @@
 
   // The comparison type used for kCompare.
   string comparison_type = 72;
+
+  // Specifies whether this is a cross-program prefetch. Used by kCopyStart.
+  bool is_cross_program_prefetch = 73;
 }
 
 // Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 1bbbb24..551ffb5 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1229,10 +1229,10 @@
   auto builder = HloComputation::Builder(TestName());
   auto constant = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
-  auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
+  auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart(
       ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
                                  ShapeUtil::MakeShape(U32, {})}),
-      HloOpcode::kCopyStart, constant));
+      constant));
   auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
       constant->shape(), HloOpcode::kCopyDone, copy_start));
   module_->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 9a4049c..bb01fdd 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -167,6 +167,11 @@
                               absl::Span<const int64>(fft_length));
       break;
     }
+    case HloOpcode::kCopyStart: {
+      instruction = CreateCopyStart(shape, operands(0),
+                                    proto.is_cross_program_prefetch());
+      break;
+    }
     case HloOpcode::kCompare: {
       // Auto-upgraded from deprecated opcode skips the following.
       if (!comparison_direction) {
@@ -839,7 +844,6 @@
     case HloOpcode::kCeil:
     case HloOpcode::kCollectivePermuteDone:
     case HloOpcode::kCopy:
-    case HloOpcode::kCopyStart:
     case HloOpcode::kCopyDone:
     case HloOpcode::kCos:
     case HloOpcode::kClz:
@@ -946,6 +950,13 @@
                                               fft_length);
 }
 
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
+    const Shape& shape, HloInstruction* operand,
+    bool is_cross_program_prefetch) {
+  return absl::make_unique<HloCopyStartInstruction>(shape, operand,
+                                                    is_cross_program_prefetch);
+}
+
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
     ComparisonDirection direction, absl::optional<Comparison::Type> type) {
@@ -4118,6 +4129,10 @@
   return Cast<HloDomainInstruction>(this)->user_side_metadata();
 }
 
+bool HloInstruction::is_cross_program_prefetch() const {
+  return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
+}
+
 ComparisonDirection HloInstruction::comparison_direction() const {
   return Cast<HloCompareInstruction>(this)->direction();
 }
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e9dca14..7db128b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -592,6 +592,12 @@
       const Shape& shape, HloInstruction* operand, FftType fft_type,
       absl::Span<const int64> fft_length);
 
+  // Creates a copy-start op, indicating whether this is a cross-program
+  // prefetch or not.
+  static std::unique_ptr<HloInstruction> CreateCopyStart(
+      const Shape& shape, HloInstruction* operand,
+      bool is_cross_program_prefetch = false);
+
   // Creates a compare op, performing the comparison specified in direction.
   static std::unique_ptr<HloInstruction> CreateCompare(
       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
@@ -1865,6 +1871,9 @@
   // Delegates to HloDomainInstruction::user_side_metadata().
   const DomainMetadata& user_side_metadata() const;
 
+  // Delegates to HloCopyStartInstruction::is_cross_program_prefetch().
+  bool is_cross_program_prefetch() const;
+
   // Delegates to HloCompareInstruction::direction().
   ComparisonDirection comparison_direction() const;
 
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index d378bef..df225e2 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -204,6 +204,47 @@
                                               fft_length_);
 }
 
+HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape,
+                                                 HloInstruction* operand,
+                                                 bool is_cross_program_prefetch)
+    : HloInstruction(HloOpcode::kCopyStart, shape),
+      is_cross_program_prefetch_(is_cross_program_prefetch) {
+  AppendOperand(operand);
+}
+
+HloInstructionProto HloCopyStartInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  proto.set_is_cross_program_prefetch(is_cross_program_prefetch_);
+  return proto;
+}
+
+std::vector<string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  std::vector<string> result;
+  if (is_cross_program_prefetch()) {
+    result.push_back("is_cross_program_prefetch=true");
+  }
+  return result;
+}
+
+bool HloCopyStartInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
+  return is_cross_program_prefetch() ==
+         casted_other.is_cross_program_prefetch();
+}
+
+std::unique_ptr<HloInstruction>
+HloCopyStartInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 1);
+  return absl::make_unique<HloCopyStartInstruction>(
+      shape, new_operands[0], is_cross_program_prefetch());
+}
+
 HloCompareInstruction::HloCompareInstruction(
     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
     ComparisonDirection direction, absl::optional<Comparison::Type> type)
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index fd2b0b7..17368e8 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -132,6 +132,28 @@
   std::vector<int64> fft_length_;
 };
 
+class HloCopyStartInstruction : public HloInstruction {
+ public:
+  explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
+                                   bool is_cross_program_prefetch);
+
+  bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }
+  HloInstructionProto ToProto() const override;
+
+ private:
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+      HloCloneContext* context) const override;
+
+  bool is_cross_program_prefetch_;
+};
+
 class HloCompareInstruction : public HloInstruction {
  public:
   explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index cb5cbd0..9c6509d 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -276,10 +276,10 @@
       /*element_size_in_bits=*/0, /*memory_space=*/2);
 
   auto p0 = HloInstruction::CreateParameter(0, shape_memspace1, "p0");
-  auto copy_start = HloInstruction::CreateUnary(
+  auto copy_start = HloInstruction::CreateCopyStart(
       ShapeUtil::MakeTupleShape(
           {shape_memspace2, shape_memspace1, ShapeUtil::MakeShape(U32, {})}),
-      HloOpcode::kCopyStart, p0.get());
+      p0.get());
   auto copy_done = HloInstruction::CreateUnary(
       shape_memspace2, HloOpcode::kCopyDone, copy_start.get());
 
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index b5680b4..e2bbda3 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -883,7 +883,6 @@
     case HloOpcode::kClz:
     case HloOpcode::kCollectivePermuteDone:
     case HloOpcode::kCopy:
-    case HloOpcode::kCopyStart:
     case HloOpcode::kCopyDone:
     case HloOpcode::kCos:
     case HloOpcode::kExp:
@@ -1091,6 +1090,20 @@
       }
       break;
     }
+    case HloOpcode::kCopyStart: {
+      // If the is_cross_program_prefetch attribute is not present, it
+      // defaults to false.
+      optional<bool> is_cross_program_prefetch = false;
+      attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
+                                            &is_cross_program_prefetch};
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateCopyStart(
+          shape, operands[0], *is_cross_program_prefetch));
+      break;
+    }
     case HloOpcode::kReplicaId: {
       if (!ParseOperands(&operands, /*expected_size=*/0) ||
           !ParseAttributes(attrs)) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index aba6aef..620e67c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -318,7 +318,7 @@
 
 ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
   %v1 = f32[] parameter(0)
-  %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1)
+  %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1), is_cross_program_prefetch=true
   %copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
   %v2 = f32[2,3]{1,0:S(1)} parameter(1)
   %copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h
index a22a394..1de231a 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_fix.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h
@@ -43,11 +43,12 @@
     while (changed_this_iteration) {
       TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
       changed |= changed_this_iteration;
-      VLOG(3) << "changed_this_iteration: " << changed_this_iteration;
+      VLOG(3) << Pass::name() << " iteration " << iteration_count
+              << " changed_this_iteration: " << changed_this_iteration;
       ++iteration_count;
       if (iteration_count == kLimit) {
-        VLOG(1) << "Unexpectedly high number of iterations in HLO passes, "
-                   "exiting fixed point loop.";
+        VLOG(1) << "Unexpectedly high number of iterations in HLO passes '"
+                << Pass::name() << "' exiting fixed point loop.";
         // Return false in case this fixed point is nested.
         return false;
       }
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index b07ab10..3b7b0b6 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -69,6 +69,9 @@
     }
     TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
     changed |= pass_changed;
+    if (pass_changed) {
+      VLOG(3) << "  Pass caused changes" << pass->name();
+    }
     TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
     last_pass_name = string(pass_name);
     if (!pass->IsPassPipeline()) {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
index e1e506b..da4e3d6 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
@@ -347,13 +347,21 @@
       index_dim++;
     }
   }
+
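+  // A partially replicated sharding keeps its replication factor as the
+  // last tile-assignment dimension, so carry that dimension over here.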
+  if (index_sharding.ReplicateOnLastTileDim()) {
+    output_tile_assignment_dims.push_back(
+        index_sharding.tile_assignment().dimensions().back());
+  }
+
   Array<int64> new_tile_assignment = index_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(output_tile_assignment_dims)) {
     return HloSharding::Replicate();
   }
   new_tile_assignment.Reshape(output_tile_assignment_dims);
-  return HloSharding::Tile(new_tile_assignment);
+  return index_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(new_tile_assignment)
+             : HloSharding::Tile(new_tile_assignment);
 }
 
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
@@ -379,13 +387,20 @@
         index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
   }
 
+  if (output_sharding.ReplicateOnLastTileDim()) {
+    index_tile_assignment_dims.push_back(
+        output_sharding.tile_assignment().dimensions().back());
+  }
+
   Array<int64> new_tile_assignment = output_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(index_tile_assignment_dims)) {
     return HloSharding::Replicate();
   }
   new_tile_assignment.Reshape(index_tile_assignment_dims);
-  return HloSharding::Tile(new_tile_assignment);
+  return output_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(new_tile_assignment)
+             : HloSharding::Tile(new_tile_assignment);
 }
 
 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
@@ -455,13 +470,19 @@
   if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
     index_tile_assignment_dims.push_back(1);
   }
+  if (data_sharding.ReplicateOnLastTileDim()) {
+    index_tile_assignment_dims.push_back(
+        data_sharding.tile_assignment().dimensions().back());
+  }
   Array<int64> new_tile_assignment = data_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(index_tile_assignment_dims)) {
     return HloSharding::Replicate();
   }
   new_tile_assignment.Reshape(index_tile_assignment_dims);
-  return HloSharding::Tile(new_tile_assignment);
+  return data_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(new_tile_assignment)
+             : HloSharding::Tile(new_tile_assignment);
 }
 
 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
@@ -481,13 +502,19 @@
       index_dim++;
     }
   }
+  if (index_sharding.ReplicateOnLastTileDim()) {
+    data_tile_assignment_dims.push_back(
+        index_sharding.tile_assignment().dimensions().back());
+  }
   Array<int64> new_tile_assignment = index_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(data_tile_assignment_dims)) {
     return HloSharding::Replicate();
   }
   new_tile_assignment.Reshape(data_tile_assignment_dims);
-  return HloSharding::Tile(new_tile_assignment);
+  return index_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(new_tile_assignment)
+             : HloSharding::Tile(new_tile_assignment);
 }
 
 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
@@ -614,9 +641,15 @@
     }
     passthrough_tile[offset_dim] = dim_partitions;
   }
+  if (operand_sharding.ReplicateOnLastTileDim()) {
+    passthrough_tile.push_back(
+        operand_sharding.tile_assignment().dimensions().back());
+  }
   Array<int64> tile_assignment = operand_sharding.tile_assignment();
   tile_assignment.Reshape(passthrough_tile);
-  return HloSharding::Tile(tile_assignment);
+  return operand_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(tile_assignment)
+             : HloSharding::Tile(tile_assignment);
 }
 
 // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
@@ -650,12 +683,19 @@
     }
     passthrough_tile[i] = dim_partitions;
   }
+
+  if (update_or_gather_sharding.ReplicateOnLastTileDim()) {
+    passthrough_tile.push_back(
+        update_or_gather_sharding.tile_assignment().dimensions().back());
+  }
   Array<int64> tile_assignment = update_or_gather_sharding.tile_assignment();
   if (tile_assignment.num_elements() != Product(passthrough_tile)) {
     return absl::nullopt;
   }
   tile_assignment.Reshape(passthrough_tile);
-  return HloSharding::Tile(tile_assignment);
+  return update_or_gather_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(tile_assignment)
+             : HloSharding::Tile(tile_assignment);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index b290b1b..2085b1e 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -516,11 +516,12 @@
         continue;
       }
 
-      VLOG(5) << "Considering fusion of: " << instruction->ToString();
       std::vector<int64>& sorted_operand_numbers = next_entry.second;
 
       for (int64 i : sorted_operand_numbers) {
         HloInstruction* operand = instruction->mutable_operand(i);
+        VLOG(5) << "Considering fusion of: " << instruction->ToString()
+                << " with operand " << operand->name();
 
         if (!operand->IsFusible()) {
           VLOG(3) << "Operand (" << operand->ToString() << ") is not fusible";
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index adccda7..55569cf 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1891,7 +1891,7 @@
             ? ShapeUtil::GetSubshape(instruction->literal().shape(),
                                      buffer.index())
                   .layout()
-            : LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
+            : GetUnconstrainedLayout(buffer);
     TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
                                                    /*mandatory=*/false));
 
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index a04d056..def620b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -27,6 +27,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -338,6 +339,9 @@
       const ResultLayoutConstraint& layout_constraint,
       LayoutConstraints* constraints);
 
+  virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) {
+    return LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
+  }
   // Called after layouts of an instruction have been finalized to allow
   // subclasses to check for platform specific assumptions.
   virtual Status Verify(const HloInstruction* instruction) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index 0371ce7..7816949 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -244,16 +244,7 @@
       } else {
         total = 0;
         for (const auto* user : indexing_users[instruction]) {
-          int64 weight = 1;
-          // Concatenate is special: the index differs for each operand, so
-          // in the worst case we have to deal with as many index values as
-          // the number of operands of Concatenate. By considering the worst
-          // case, we are more conservative than necessary regarding
-          // refusing to fuse.
-          if (user->opcode() == HloOpcode::kConcatenate) {
-            weight = user->operand_count();
-          }
-          total += index_usage_count[user] * weight;
+          total += index_usage_count[user];
         }
       }
       for (const auto* operand : instruction->operands()) {
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index f3957b2..12cdb17 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -236,15 +236,26 @@
 }
 
 int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
-    const HloUse& use, int64 start_time, int64 end_time) const {
+    const Shape& shape, int64 start_time, int64 end_time,
+    const HloUse* use) const {
   return end_time - min_overlap_count_;
 }
 
+int64 InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime(
+    const Shape& shape, int64 earliest_prefetch_start_time,
+    int64 latest_prefetch_start_time, int64 prefetch_end_time) const {
+  return std::max(earliest_prefetch_start_time,
+                  prefetch_end_time - max_overlap_count_);
+}
+
 void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
                                                    int64 start_time,
                                                    int64 end_time) {
   end_time_ = end_time;
-  current_prefetch_time_ = std::max(start_time, end_time_ - max_overlap_count_);
+  const Shape& shape = ShapeUtil::GetSubshape(
+      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
+  current_prefetch_time_ =
+      PreferredPrefetchStartTime(shape, start_time, end_time, end_time);
 }
 
 int64 InstructionCountPrefetchIntervalPicker::Next() {
@@ -361,18 +372,22 @@
 }
 
 int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
-    const HloUse& use, int64 start_time, int64 end_time) const {
-  const Shape& shape = ShapeUtil::GetSubshape(
-      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
+    const Shape& shape, int64 start_time, int64 end_time,
+    const HloUse* use) const {
   // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
-  // Estimate the time we would save by having this op in alternate memory.
-  float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
-  float elapsed_time_in_alternate_mem =
-      cost_analysis_.GetInstructionElapsedInAlternateMemory(
-          *use.instruction, use.operand_number,
-          /*output_in_alternate_mem=*/false);
-  float inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
+  // If there is a use, estimate the time we would save by having this op in
+  // alternate memory.
+  float inst_elapsed_reduction = 0.0f;
+  if (use) {
+    float elapsed_time =
+        cost_analysis_.GetInstructionElapsed(*use->instruction);
+    float elapsed_time_in_alternate_mem =
+        cost_analysis_.GetInstructionElapsedInAlternateMemory(
+            *use->instruction, use->operand_number,
+            /*output_in_alternate_mem=*/false);
+    inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
+  }
   int end_nest_level = while_nest_level_[end_time];
 
   // Find the latest time we're allowed to start prefetching.
@@ -390,6 +405,33 @@
   return latest_prefetch_time;
 }
 
+int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime(
+    const Shape& shape, int64 earliest_prefetch_start_time,
+    int64 latest_prefetch_start_time, int64 prefetch_end_time) const {
+  // Between the earliest and latest prefetch interval, find the interval
+  // closest to the preferred interval and start iterating from there.
+  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
+  int64 preferred_prefetch_start_time = earliest_prefetch_start_time;
+  float preferred_interval =
+      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed;
+  float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time,
+                                                  prefetch_end_time);
+  int end_nest_level = while_nest_level_[prefetch_end_time];
+  for (int64 prefetch_start_time = earliest_prefetch_start_time + 1;
+       prefetch_start_time <= latest_prefetch_start_time;
+       ++prefetch_start_time) {
+    float interval =
+        GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time);
+    if (while_nest_level_[prefetch_start_time] == end_nest_level &&
+        std::abs(preferred_interval - interval) <
+            std::abs(preferred_interval - best_interval)) {
+      best_interval = interval;
+      preferred_prefetch_start_time = prefetch_start_time;
+    }
+  }
+  return preferred_prefetch_start_time;
+}
+
 int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
     int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const {
   // Iterate towards the beginning until we find a suitable end time that is the
@@ -422,7 +464,8 @@
 
   // Find the latest time we're allowed to start prefetching.
   float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
-  latest_prefetch_time_ = LatestPrefetchStartTime(use, start_time, end_time);
+  latest_prefetch_time_ =
+      LatestPrefetchStartTime(shape, start_time, end_time, &use);
 
   // Find the earliest time we're allowed to start prefetching.
   float max_interval = max_async_copy_to_overlap_ratio_ *
@@ -443,24 +486,10 @@
     return;
   }
 
-  // Between the earliest and latest prefetch interval, find the interval
-  // closest to the preferred interval and start iterating from there.
-  int64 starting_prefetch_time = earliest_prefetch_time_;
+  int64 starting_prefetch_time = PreferredPrefetchStartTime(
+      shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_);
   float preferred_interval =
       preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
-  float best_interval =
-      GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_);
-  for (int64 prefetch_time = earliest_prefetch_time_ + 1;
-       prefetch_time <= latest_prefetch_time_; ++prefetch_time) {
-    float interval =
-        GetLogicalIntervalElapsed(prefetch_time, end_logical_time_);
-    if (while_nest_level_[prefetch_time] == end_nest_level &&
-        std::abs(preferred_interval - interval) <
-            std::abs(preferred_interval - best_interval)) {
-      best_interval = interval;
-      starting_prefetch_time = prefetch_time;
-    }
-  }
   VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
           << max_interval << " " << preferred_interval
           << " prefetch time earliest/latest/starting = "
@@ -1244,10 +1273,13 @@
         }
       }
 
-      // Bitcasts don't define buffers and don't directly consume buffers.  Skip
-      // allocating buffers for bitcast uses. The uses that feed from bitcasts
-      // will be handled specially.
-      if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) {
+      // Bitcasts don't define buffers and don't directly consume buffers. Skip
+      // allocating buffers for bitcast uses (unless they are the root
+      // instruction). The uses that feed from bitcasts will be handled
+      // specially.
+      if (hlo_use.instruction->opcode() != HloOpcode::kBitcast ||
+          hlo_use.instruction ==
+              hlo_use.instruction->parent()->root_instruction()) {
         AllocationRequest request;
         // Rarely, (e.g., when conditional true and false parameters are the
         // same), definition time can be the time of the conditional and use
@@ -1371,32 +1403,79 @@
   // Find the earliest use.
   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
   auto uses = buffer->uses();
-  auto first_use =
-      absl::c_min_element(uses, [&](const HloUse& lhs, const HloUse& rhs) {
-        return instruction_schedule.at(lhs.instruction) <
-               instruction_schedule.at(rhs.instruction);
-      });
+  auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) {
+    return instruction_schedule.at(lhs.instruction) <
+           instruction_schedule.at(rhs.instruction);
+  };
+  auto first_use = absl::c_min_element(uses, use_schedule_compare);
   int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction);
 
+  // Find the latest use time.
+  int64 last_use_time = instruction_schedule.at(
+      absl::c_max_element(uses, use_schedule_compare)->instruction);
+  for (const HloValue* colocation : prefetch_candidate->colocations) {
+    last_use_time = std::max(
+        last_use_time,
+        instruction_schedule.at(
+            absl::c_max_element(colocation->uses(), use_schedule_compare)
+                ->instruction));
+  }
+
+  int64 end_of_program_prefetch_end_time = instruction_schedule.size() - 1;
+  int64 end_of_program_prefetch_start_time =
+      options_.prefetch_interval_picker->PreferredPrefetchStartTime(
+          buffer->defining_position().shape(), last_use_time,
+          end_of_program_prefetch_end_time, end_of_program_prefetch_end_time);
+  VLOG(2) << "last use time = " << last_use_time
+          << ", end-of-program prefetch start time = "
+          << end_of_program_prefetch_start_time;
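+  // The buffer can be freed after its last use only if there is still
+  // room to prefetch it back in before the end of the program.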
+  bool free_buffer =
+      (end_of_program_prefetch_start_time > last_use_time &&
+       end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
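+  // If the buffer is freed, the cross-program prefetch only needs to live
+  // until the last use; otherwise it stays live for the whole interval.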
+  int64 cross_program_prefetch_end_time =
+      free_buffer ? last_use_time : prefetch_candidate->end;
+
   AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
                chunk_candidate.chunk, prefetch_candidate->start,
-               prefetch_candidate->end, latest_prefetch_time, &allocations);
+               cross_program_prefetch_end_time, latest_prefetch_time,
+               &allocations,
+               /*is_cross_program_prefetch=*/true);
   absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
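+  // Remember the chunk offset so that we can verify below that the
+  // end-of-program prefetch is assigned the same chunk.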
+  int64 cross_program_prefetch_offset = allocations.back()->chunk().offset;
+
+  if (free_buffer) {
+    VLOG(2) << "Adding an end-of-program prefetch for freed "
+               "cross-program-prefetched buffer.";
+    AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate,
+                 chunk_candidate.chunk, end_of_program_prefetch_start_time,
+                 end_of_program_prefetch_end_time,
+                 end_of_program_prefetch_end_time, &allocations);
+    CHECK_EQ(cross_program_prefetch_offset, allocations.back()->chunk().offset);
+  }
+
   for (auto& allocation : allocations) {
     allocations_->push_back(std::move(allocation));
   }
-  // Add a repack allocation block for the Allocation object in alternate
+
+  // Add repack allocation blocks for the Allocation objects in alternate
   // memory.
-  CHECK_EQ(allocations_->size(), 2);
-  MemorySpaceAssignment::Allocation* last_allocation =
-      allocations_->at(1).get();
-  CHECK(last_allocation->memory_space() == MemorySpace::kAlternate);
-  repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
-      last_allocation->start_time(), last_allocation->end_time(),
-      last_allocation->chunk().size, last_allocation->chunk().offset,
-      static_cast<int64>(repack_allocation_blocks_.size()), last_allocation));
-  repack_allocation_blocks_.back().colocations.push_back(
-      &repack_allocation_blocks_.back());
+  CHECK_EQ(repack_allocation_blocks_.size(), 0);
+  for (const auto& allocation : *allocations_) {
+    if (allocation->memory_space() == MemorySpace::kAlternate) {
+      repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
+          allocation->start_time(), allocation->end_time(),
+          allocation->chunk().size, allocation->chunk().offset,
+          static_cast<int64>(repack_allocation_blocks_.size()),
+          allocation.get()));
+      RepackAllocationBlock* inserted = &repack_allocation_blocks_.back();
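+      // Colocate the new block with every existing one (and vice versa) so
+      // the repacker keeps these allocations at a common offset.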
+      for (RepackAllocationBlock& colocation : repack_allocation_blocks_) {
+        colocation.colocations.push_back(inserted);
+        if (&colocation != inserted) {
+          inserted->colocations.push_back(&colocation);
+        }
+      }
+    }
+  }
 
   ClearPendingChunks();
 }
@@ -1858,7 +1937,8 @@
     const MemorySpaceAssignment::Allocation& prev_allocation,
     MemorySpace memory_space, absl::optional<Chunk> chunk, int64 start_time,
     int64 end_time, int64 copy_done_schedule_before_time,
-    MemorySpaceAssignment::AllocationSequence* allocations) {
+    MemorySpaceAssignment::AllocationSequence* allocations,
+    bool is_cross_program_prefetch) {
   VLOG(3) << "Copy to "
           << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
                   ? "default"
@@ -1870,7 +1950,7 @@
   allocations->push_back(
       absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
           prev_allocation, memory_space, chunk, start_time, end_time,
-          copy_done_schedule_before_time));
+          copy_done_schedule_before_time, is_cross_program_prefetch));
 
   // Register the additional async copy with the interval tree to keep track of
   // the limit at any given time.
@@ -2132,12 +2212,15 @@
     const AllocationRequest& request, int64 earliest_prefetch_time) const {
   int64 prefetch_end_time = request.latest_prefetch_time;
 
+  const HloUse& use = request.use->hlo_use;
+  const Shape& shape = ShapeUtil::GetSubshape(
+      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
   for (int retry_number = 0;
        retry_number < options_.prefetch_copy_done_reorder_max_retries;
        ++retry_number) {
     int64 latest_prefetch_time =
         options_.prefetch_interval_picker->LatestPrefetchStartTime(
-            request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
+            shape, earliest_prefetch_time, prefetch_end_time, &use);
     VLOG(4) << "Latest prefetch start time = " << latest_prefetch_time
             << ", earliest prefetch start time = " << earliest_prefetch_time
             << ", prefetch end time = " << prefetch_end_time;
@@ -2444,7 +2527,9 @@
     const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range,
     const MemorySpaceAssignment::Options& options) {
   std::vector<MemorySpaceAssignment::BufferInterval> candidates;
-  for (HloValue* value : alias_analysis.dataflow_analysis().values()) {
+  for (const HloBuffer& buffer : alias_analysis.buffers()) {
+    CHECK_GE(buffer.values().size(), 1);
+    const HloValue* value = buffer.values().at(0);
     if (IsCrossProgramPrefetchCandidate(*value, options)) {
       MemorySpaceAssignment::BufferInterval interval;
       interval.buffer = value;
@@ -2452,6 +2537,7 @@
       interval.start = 0;
       interval.end = hlo_live_range.schedule_end_time();
       interval.need_allocation = true;
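+      // The remaining values of this buffer alias the first, so record
+      // them as colocations of the prefetch candidate.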
+      interval.colocations = {++buffer.values().begin(), buffer.values().end()};
       candidates.emplace_back(interval);
     }
   }
@@ -2681,9 +2767,9 @@
   Shape shape = defining_position().shape();
   HloInstruction* producing_instruction = AddGetTupleElements();
   HloComputation* computation = producing_instruction->parent();
-  copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
+  copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
       ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
-      HloOpcode::kCopyStart, producing_instruction));
+      producing_instruction, is_cross_program_prefetch_));
   copy_done_ = computation->AddInstruction(
       HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
   VLOG(4) << "Created " << copy_start_->name()
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 577554a..0473766 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -200,8 +200,15 @@
                                          int64 latest_end_time) const = 0;
 
   // Returns the latest time that a prefetch can start.
-  virtual int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
-                                        int64 end_time) const = 0;
+  virtual int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
+                                        int64 end_time,
+                                        const HloUse* use) const = 0;
+
+  // Returns the preferred time that a prefetch can start.
+  virtual int64 PreferredPrefetchStartTime(const Shape& shape,
+                                           int64 earliest_prefetch_start_time,
+                                           int64 latest_prefetch_start_time,
+                                           int64 prefetch_end_time) const = 0;
 
   // Returns the latest time that a prefetch can end that is less than or equal
   // to proposed_prefetch_end_time.
@@ -269,8 +276,14 @@
   int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
                                  int64 latest_end_time) const override;
 
-  int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
-                                int64 end_time) const override;
+  int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
+                                int64 end_time,
+                                const HloUse* use) const override;
+
+  int64 PreferredPrefetchStartTime(const Shape& shape,
+                                   int64 earliest_prefetch_start_time,
+                                   int64 latest_prefetch_start_time,
+                                   int64 prefetch_end_time) const override;
 
   void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
 
@@ -308,11 +321,18 @@
   int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
                                  int64 latest_end_time) const override;
 
-  int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time,
-                                int64 end_time) const override;
   int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
                               int64 proposed_prefetch_end_time) const override;
 
+  int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
+                                int64 end_time,
+                                const HloUse* use) const override;
+
+  int64 PreferredPrefetchStartTime(const Shape& shape,
+                                   int64 earliest_prefetch_start_time,
+                                   int64 latest_prefetch_start_time,
+                                   int64 prefetch_end_time) const override;
+
   void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
 
   int64 Next() override;
@@ -561,12 +581,14 @@
    public:
     CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
                    absl::optional<Chunk> chunk, int64 start_time,
-                   int64 end_time, int64 copy_done_schedule_before_time)
+                   int64 end_time, int64 copy_done_schedule_before_time,
+                   bool is_cross_program_prefetch = false)
         : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
                      start_time, end_time),
           prev_allocation_(prev_allocation),
           copy_start_schedule_after_(start_time),
-          copy_done_schedule_before_(copy_done_schedule_before_time) {}
+          copy_done_schedule_before_(copy_done_schedule_before_time),
+          is_cross_program_prefetch_(is_cross_program_prefetch) {}
 
     bool is_copy_allocation() const override { return true; }
 
@@ -606,6 +628,10 @@
       copy_start_schedule_after_ = copy_start_schedule_after;
     }
 
+    bool is_cross_program_prefetch() const {
+      return is_cross_program_prefetch_;
+    }
+
     bool operator==(const CopyAllocation& other) const;
     std::string ToString() const override;
 
@@ -617,6 +643,7 @@
     // is before copy_done_schedule_before_.
     int64 copy_start_schedule_after_;
     int64 copy_done_schedule_before_;
+    bool is_cross_program_prefetch_;
     HloInstruction* copy_start_;
     HloInstruction* copy_done_;
   };
@@ -1188,7 +1215,8 @@
                     MemorySpace memory_space, absl::optional<Chunk> chunk,
                     int64 start_time, int64 end_time,
                     int64 copy_done_schedule_before_time,
-                    MemorySpaceAssignment::AllocationSequence* allocations);
+                    MemorySpaceAssignment::AllocationSequence* allocations,
+                    bool is_cross_program_prefetch = false);
 
   // This method is used for committing the chunk candidate but adding it to
   // pending_chunks_ so that we can "uncommit" them in case we need to roll back
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 22acc17..f9ca0f8 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -4066,6 +4066,51 @@
             find_schedule_index(cos->operand(0)));
 }
 
+TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
+  // Tests against a bug where the root of the entry computation is a bitcast
+  // instruction and it ends up getting an allocation in the alternate memory.
+  absl::string_view hlo_string = R"(
+HloModule primitive_computation_gather.4, is_scheduled=true
+
+%while_body {
+  %param.1 = (s32[], f32[3,3,3]) parameter(0)
+  %get-tuple-element.32 = s32[] get-tuple-element(%param.1), index=0
+  %copy.6 = s32[] copy(s32[] %get-tuple-element.32)
+  %constant.8 = s32[] constant(1)
+  %add = s32[] add(s32[] %copy.6, s32[] %constant.8)
+  %get-tuple-element.35 = f32[3,3,3] get-tuple-element(%param.1), index=1
+  negate = f32[3,3,3] negate(get-tuple-element.35)
+  ROOT %tuple.10 = (s32[], f32[3,3,3]) tuple(s32[] %add, f32[3,3,3] negate)
+}
+
+%while_cond {
+  %param.0 = (s32[], f32[3,3,3]) parameter(0)
+  %get-tuple-element = s32[] get-tuple-element(%param.0), index=0
+  %constant.3 = s32[] constant(3)
+  ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.3), direction=LT
+}
+
+ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s32[3,1]) -> f32[3,3,3] {
+  %constant.1 = s32[] constant(0)
+  %copy.11 = s32[] copy(s32[] %constant.1)
+  %constant = f32[] constant(0)
+  %broadcast = f32[3,3,3] broadcast(f32[] %constant), dimensions={}
+  %tuple.8 = (s32[], f32[3,3,3]) tuple(s32[] %copy.11, f32[3,3,3] %broadcast)
+  %while = (s32[], f32[3,3,3]) while(%tuple.8), condition=%while_cond, body=%while_body
+  %get-tuple-element.7 = f32[3,3,3] get-tuple-element(%while), index=1
+  ROOT %bitcast.1 = f32[3,3,3] bitcast(f32[3,3,3] %get-tuple-element.7)
+}
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+
+  const HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_TRUE(!root->shape().has_layout() ||
+              root->shape().layout().memory_space() == kDefaultMemorySpace);
+}
+
 // A mock MemorySpaceAssignmentRepacker class that accepts a map of
 // (start_time,offset) -> new_offset values. Using this map, the repacker
 // repacks the allocations to the new_offset.
@@ -4566,6 +4611,125 @@
   EXPECT_EQ(cross_program_prefetches.size(), 0);
 }
 
+TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) {
+  // This test checks that the cross-program-prefetched buffer is freed after
+  // its last use and that there is an end-of-program prefetch.
+  absl::string_view hlo_string = R"(
+  HloModule cross_program_prefetch, is_scheduled=true
+
+  ENTRY CrossProgramPrefetch {
+    p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
+    get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
+    get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
+    dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    negate.1 = f32[8,2]{1,0} negate(dot)
+    negate.2 = f32[8,2]{1,0} negate(negate.1)
+    negate.3 = f32[8,2]{1,0} negate(negate.2)
+    negate.4 = f32[8,2]{1,0} negate(negate.3)
+    negate.5 = f32[8,2]{1,0} negate(negate.4)
+    negate.6 = f32[8,2]{1,0} negate(negate.5)
+    negate.7 = f32[8,2]{1,0} negate(negate.6)
+    negate.8 = f32[8,2]{1,0} negate(negate.7)
+    ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
+                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
+
+  auto cross_program_prefetches = module->CrossProgramPrefetches();
+  EXPECT_EQ(cross_program_prefetches.size(), 1);
+  if (!cross_program_prefetches.empty()) {
+    EXPECT_EQ(cross_program_prefetches[0].first, 0);
+    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
+  }
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
+      HloDataflowAnalysis::Run(*module));
+  const HloValue& cross_program_prefetched_value =
+      dataflow_analysis->GetValueDefinedAt(
+          module->entry_computation()->parameter_instruction(0), {1});
+  // Expect that there are two prefetches that use this value: one is the
+  // cross-program prefetch, the other is the end-of-program prefetch.
+  auto is_cross_program_prefetch = [](const HloUse& use) {
+    return use.instruction->opcode() == HloOpcode::kCopyStart &&
+           use.instruction->is_cross_program_prefetch();
+  };
+  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(),
+                             is_cross_program_prefetch),
+            1);
+  auto is_end_of_program_prefetch = [](const HloUse& use) {
+    return use.instruction->opcode() == HloOpcode::kCopyStart &&
+           !use.instruction->is_cross_program_prefetch();
+  };
+  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(),
+                             is_end_of_program_prefetch),
+            1);
+}
+
+TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
+  // This tests the scenario where the cross-program-prefetched buffer is used
+  // again close to the end of the computation. In this case, it is better not
+  // to free the buffer.
+  absl::string_view hlo_string = R"(
+  HloModule cross_program_prefetch, is_scheduled=true
+
+  ENTRY CrossProgramPrefetch {
+    p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
+    get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
+    get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
+    dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    negate.1 = f32[8,2]{1,0} negate(dot)
+    negate.2 = f32[8,2]{1,0} negate(negate.1)
+    negate.3 = f32[8,2]{1,0} negate(negate.2)
+    negate.4 = f32[8,2]{1,0} negate(negate.3)
+    negate.5 = f32[8,2]{1,0} negate(negate.4)
+    negate.6 = f32[8,2]{1,0} negate(negate.5)
+    negate.7 = f32[8,2]{1,0} negate(negate.6)
+    negate.8 = f32[8,2]{1,0} negate(negate.7)
+    ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
+                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
+
+  auto cross_program_prefetches = module->CrossProgramPrefetches();
+  EXPECT_EQ(cross_program_prefetches.size(), 1);
+  if (!cross_program_prefetches.empty()) {
+    EXPECT_EQ(cross_program_prefetches[0].first, 0);
+    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
+  }
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
+      HloDataflowAnalysis::Run(*module));
+  const HloValue& cross_program_prefetched_value =
+      dataflow_analysis->GetValueDefinedAt(
+          module->entry_computation()->parameter_instruction(0), {1});
+  // Expect that there is one prefetch that uses this value: the cross-program
+  // prefetch. There shouldn't be an end-of-program prefetch.
+  auto is_cross_program_prefetch = [](const HloUse& use) {
+    return use.instruction->opcode() == HloOpcode::kCopyStart &&
+           use.instruction->is_cross_program_prefetch();
+  };
+  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(),
+                             is_cross_program_prefetch),
+            1);
+  auto is_end_of_program_prefetch = [](const HloUse& use) {
+    return use.instruction->opcode() == HloOpcode::kCopyStart &&
+           !use.instruction->is_cross_program_prefetch();
+  };
+  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(),
+                             is_end_of_program_prefetch),
+            0);
+}
+
 using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
 
 TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
@@ -4790,11 +4954,12 @@
 
   HloInstruction* root = module->entry_computation()->root_instruction();
   const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
+  const Shape& shape = root->operand(1)->shape();
 
   // We expect the root's latest prefetch start time to be before the while loop
   // (logical time 4).
-  EXPECT_EQ(interval_picker.LatestPrefetchStartTime(use, /*start_time=*/0,
-                                                    /*end_time=*/23),
+  EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0,
+                                                    /*end_time=*/23, &use),
             4);
 }
 
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
index 68bcde4..af670eb 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD
+++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
@@ -41,9 +41,12 @@
     srcs = ["emission_context.cc"],
     hdrs = ["emission_context.h"],
     deps = [
+        "//tensorflow/compiler/mlir/hlo",
+        "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/xla/service:hlo",
         "@com_google_absl//absl/strings",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
     ],
 )
 
@@ -182,7 +185,6 @@
     deps = [
         ":passes",
         "//tensorflow/compiler/mlir/hlo",
-        "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
         "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
         "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation",
         "//tensorflow/compiler/mlir/hlo:legalize_to_linalg",
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc
index ca97926..06c7ebd 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc
@@ -16,8 +16,11 @@
 #include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
 
 #include "absl/strings/substitute.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 
 namespace xla {
@@ -25,6 +28,8 @@
 
 EmissionContext::EmissionContext(std::unique_ptr<HloModule> module)
     : module_(std::move(module)), context_() {
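+  // Load the dialects used during emission explicitly, since the context no
+  // longer picks up globally registered dialects automatically.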
+  context_.loadDialect<mlir::mhlo::MhloDialect, mlir::lmhlo::LmhloDialect,
+                       mlir::StandardOpsDialect>();
   error_handler_ = [](const ErrorMap& instructions_with_error,
                       HloModule* module) {
     std::set<const HloComputation*> computations_with_error;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc
index d5cad38..f7a7decf 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc
@@ -46,6 +46,7 @@
       hlo_module.entry_computation()->root_instruction();
 
   mlir::MLIRContext context;
+  context.loadAllGloballyRegisteredDialects();
   mlir::OwningModuleRef mlir_module(
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)));
 
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index ae99cc9..1b2edec 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/SCF/Passes.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/BufferPlacement.h"  // from @llvm-project
@@ -143,6 +144,10 @@
 class LowerToNVVMPass
     : public ::mlir::PassWrapper<
           LowerToNVVMPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
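+  // Declare the dialects this pass may create so they are loaded into the
+  // context before the pass runs.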
+  void getDependentDialects(mlir::DialectRegistry& registry) const override {
+    registry.insert<mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
+  }
+
  public:
   void runOnOperation() override {
     ::mlir::gpu::GPUModuleOp m = getOperation();
diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
index 03c77c2..b89e84e 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
@@ -1774,6 +1774,28 @@
               op::Sharding("{devices=[2,1]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherFromIndex_PartialReplicate) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0), sharding={replicated}
+  %indices = s32[3] parameter(1),
+   sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate}
+  %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
+    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
+    slice_sizes={1,9}
+  ROOT %copy = f32[3,9] copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "gather"),
+              op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, GatherFromDataOperand) {
   const char* hlo_string = R"(
 HloModule module
@@ -1795,6 +1817,28 @@
               op::Sharding("{devices=[1,2]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherFromDataOperand_PartialReplicate) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0),
+    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+  %indices = s32[3] parameter(1), sharding={replicated}
+  %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
+    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
+    slice_sizes={1,9}
+  ROOT %copy = f32[3,9] copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "gather"),
+              op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, GatherToIndex) {
   const char* hlo_string = R"(
 HloModule module
@@ -1816,6 +1860,28 @@
               op::Sharding("{devices=[2]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherToIndex_PartialReplicate) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0), sharding={replicated}
+  %p1 = s32[3] parameter(1)
+  %indices = s32[3] copy(%p1)
+  ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
+    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
+    slice_sizes={1,9},
+    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "indices"),
+              op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, GatherToIndex2) {
   const char* hlo_string = R"(
 HloModule module
@@ -1839,6 +1905,30 @@
               op::Sharding("{devices=[1,2,1]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherToIndex2_PartialReplicate) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = bf16[2,4819,4] parameter(0), sharding={replicated}
+  %p1 = s32[2,1000,2] parameter(1)
+  %indices = s32[2,1000,2] copy(%p1)
+  ROOT %gather = bf16[2,1000,4]
+    gather(bf16[2,4819,4] %input, s32[2,1000,2] %indices),
+    offset_dims={2}, collapsed_slice_dims={0,1},
+    start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1,4},
+    sharding={devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      FindInstruction(module.get(), "indices"),
+      op::Sharding("{devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, GatherToIndex3) {
   const char* hlo_string = R"(
 HloModule module
@@ -1883,6 +1973,27 @@
               op::Sharding("{devices=[1,2]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherToDataOperand_PartialReplicate) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %p0 = f32[2,9] parameter(0)
+  %input = f32[2,9] copy(%p0)
+  %indices = s32[3] parameter(1), sharding={replicated}
+  ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
+    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
+    slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "input"),
+              op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, DataOperandToScatter) {
   const char* const hlo_string = R"(
 HloModule module
@@ -1914,6 +2025,38 @@
               op::Sharding("{devices=[1,2]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, DataOperandToScatter_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0),
+   sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+  %indices = s32[3] parameter(1), sharding={replicated}
+  %updates = f32[3,9] parameter(2), sharding={replicated}
+  %scatter = f32[2,9] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1
+  ROOT %copy = f32[2,9] copy(%scatter)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "scatter"),
+              op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, UpdateOperandToScatter) {
   const char* const hlo_string = R"(
 HloModule module
@@ -1945,6 +2088,70 @@
               op::Sharding("{devices=[1,2]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, UpdateOperandToScatter_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0), sharding={replicated}
+  %indices = s32[3] parameter(1), sharding={replicated}
+  %updates = f32[3,9] parameter(2),
+    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+  %scatter = f32[2,9] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1
+  ROOT %copy = f32[2,9] copy(%scatter)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "scatter"),
+              op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
+TEST_F(ShardingPropagationTest, ScatterToDataOperand_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %p0 = f32[2,9] parameter(0)
+  %input = f32[2,9] copy(%p0)
+  %indices = s32[3] parameter(1), sharding={replicated}
+  %updates = f32[3,9] parameter(2), sharding={replicated}
+  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "input"),
+              op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, ScatterToDataOperand) {
   const char* const hlo_string = R"(
 HloModule module
@@ -1976,6 +2183,38 @@
               op::Sharding("{devices=[1,2]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, ScatterToUpdateOperand_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0)
+  %indices = s32[3] parameter(1), sharding={replicated}
+  %p2 = f32[3,9] parameter(2)
+  %updates = f32[3,9] copy(%p2)
+  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "updates"),
+              op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) {
   const char* const hlo_string = R"(
 HloModule module
@@ -2038,6 +2277,38 @@
               op::Sharding("{devices=[2]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, ScatterUpdateToIndex_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0), sharding={replicated}
+  %p1 = s32[3] parameter(1), sharding={replicated}
+  %indices = s32[3] copy(%p1)
+  %updates = f32[3,9] parameter(2),
+    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1, sharding={replicated}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "indices"),
+              op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) {
   const char* const hlo_string = R"(
 HloModule module
@@ -2069,6 +2340,38 @@
               op::Sharding("{devices=[2,1]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, ScatterIndexToUpdate_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0), sharding={replicated}
+  %indices = s32[3] parameter(1),
+    sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate}
+  %p2 = f32[3,9] parameter(2), sharding={replicated}
+  %updates = f32[3,9] copy(%p2)
+  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1, sharding={replicated}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "updates"),
+              op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}"));
+}
+
 TEST_F(ShardingPropagationTest, PartialShardingOnElementwise) {
   const char* const hlo_string = R"(
 HloModule module
diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD
index dd3da79..d2243d3 100644
--- a/tensorflow/compiler/xla/service/spmd/BUILD
+++ b/tensorflow/compiler/xla/service/spmd/BUILD
@@ -74,3 +74,16 @@
         "//tensorflow/core:test",
     ],
 )
+
+cc_library(
+    name = "schedule_aware_all_gather_cse",
+    srcs = ["schedule_aware_all_gather_cse.cc"],
+    hdrs = ["schedule_aware_all_gather_cse.h"],
+    deps = [
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index 4075dc2..da43296 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -931,9 +931,13 @@
         AlignGroupsWith(GroupShardingOnDims(other.sharding(), other_group_dims),
                         output_grouped, /*ignore_group_order=*/true);
     other = other.Reshard(UngroupSharding(other_grouped));
-    // TODO(yuanzx): Use reshard to replicate when ready.
     partially_replicated_other =
-        other.ReplicatePartial(other_grouped.group_dims);
+        other
+            .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
+                other.sharding(), other_grouped.group_dims))
+            .hlo();
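+    // Record the current sharding so it can be restored later; the grouped
+    // sharding assigned below is only meaningful within each group.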
+    top_level_sharding_to_reset.emplace_back(
+        partially_replicated_other, partially_replicated_other->sharding());
     partially_replicated_other->set_sharding(other_grouped.sharding);
   }
   auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc
new file mode 100644
index 0000000..cc97d5e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc
@@ -0,0 +1,132 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h"
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace xla {
+namespace {
+
+HloCollectiveInstruction* MayConsiderAsAllGather(HloInstruction* hlo,
+                                                 bool for_replicas) {
+  auto coll = DynCast<HloCollectiveInstruction>(hlo);
+  if (!coll) {
+    return nullptr;
+  }
+  if (coll->constrain_layout()) {
+    return nullptr;
+  }
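+  // Cross-replica collectives carry no channel id while cross-partition ones
+  // do, so only collectives matching the requested mode are considered.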
+  if (for_replicas == coll->channel_id().has_value()) {
+    return nullptr;
+  }
+  if (coll->opcode() == HloOpcode::kAllGather) {
+    return coll;
+  }
+  // Consider broadcast -> dynamic-update-slice -> all-reduce as all-gather.
+  if (coll->opcode() == HloOpcode::kAllReduce && coll->shape().IsArray()) {
+    auto operand = coll->operand(0);
+    return operand->opcode() == HloOpcode::kDynamicUpdateSlice &&
+                   operand->operand(0)->opcode() == HloOpcode::kBroadcast
+               ? coll
+               : nullptr;
+  }
+  return nullptr;
+}
+
+StatusOr<bool> RunOnComputation(HloComputation* comp, bool for_replicas,
+                                int64 distance_threshold) {
+  // We estimate the live ranges of all-gathers by comparing their users'
+  // distance to the root, i.e., their height.
+  absl::flat_hash_map<const HloInstruction*, int64> height;
+  auto ordered_hlos = comp->MakeInstructionPostOrder();
+  int64 max_height = 0;
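+  // Visit the post-order in reverse so that every user is processed before
+  // its operands; height[hlo] then upper-bounds hlo's distance to the root.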
+  for (auto it = ordered_hlos.rbegin(); it != ordered_hlos.rend(); ++it) {
+    auto hlo = *it;
+    int64 h = 0;
+    for (auto user : hlo->users()) {
+      h = std::max(h, height[user]) + 1;
+    }
+    max_height = std::max(max_height, h);
+    height[hlo] = h;
+  }
+
+  auto lowest_user_height = [&](const HloInstruction* hlo) {
+    int64 lowest = height[hlo];
+    for (auto user : hlo->users()) {
+      lowest = std::min(lowest, height[user]);
+    }
+    return lowest;
+  };
+
+  absl::flat_hash_map<const HloInstruction*,
+                      std::vector<HloCollectiveInstruction*>>
+      operand_to_ag;
+  bool changed = false;
+  for (auto hlo : ordered_hlos) {
+    auto ag = MayConsiderAsAllGather(hlo, for_replicas);
+    if (!ag) {
+      continue;
+    }
+
+    auto& earlier_ags = operand_to_ag[ag->operand(0)];
+    bool found = false;
+    int64 lowest_user_h = lowest_user_height(ag);
+    for (auto& eag : earlier_ags) {
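+      // Temporarily give ag the same channel id as eag so that Identical()
+      // does not treat differing channel ids as a mismatch; the original id
+      // is restored below.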
+      auto old_channel_id = ag->channel_id();
+      if (eag->channel_id() && ag->channel_id()) {
+        ag->set_channel_id(eag->channel_id());
+      }
+      if (!eag->Identical(*ag)) {
+        ag->set_channel_id(old_channel_id);
+        continue;
+      }
+      found = true;
+      ag->set_channel_id(old_channel_id);
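+      // If reusing eag would extend its live range past the threshold, make
+      // ag the new CSE representative instead of replacing it.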
+      if (lowest_user_height(eag) > lowest_user_h + distance_threshold) {
+        eag = ag;
+        continue;
+      }
+      changed = true;
+      VLOG(1) << "Replacing " << ag->ToString() << " with " << eag->ToString();
+      TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(eag));
+      break;
+    }
+    if (!found) {
+      earlier_ags.push_back(ag);
+    }
+  }
+  return changed;
+}
+
+}  // namespace
+
+StatusOr<bool> ScheduleAwareAllGatherCSE::Run(HloModule* module) {
+  bool changed = false;
+  for (auto comp : module->computations()) {
+    TF_ASSIGN_OR_RETURN(
+        auto comp_changed,
+        RunOnComputation(comp, for_replicas_, distance_threshold_));
+    changed |= comp_changed;
+  }
+  return changed;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h
new file mode 100644
index 0000000..4653286
--- /dev/null
+++ b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Performs CSE for all-gathers whose users are within a reasonable live range.
+class ScheduleAwareAllGatherCSE : public HloModulePass {
+ public:
+  // distance_threshold: maximum live range (in number of HLO instructions on
+  //   the path) within which CSE is considered.
+  // for_replicas: specifies if this pass is for cross-replica or
+  //   cross-partition all-gathers.
+  explicit ScheduleAwareAllGatherCSE(int64 distance_threshold,
+                                     bool for_replicas)
+      : distance_threshold_(distance_threshold), for_replicas_(for_replicas) {}
+
+  ~ScheduleAwareAllGatherCSE() override = default;
+  absl::string_view name() const override {
+    return "schedule-aware-all-gather-cse";
+  }
+
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  int64 distance_threshold_;
+  bool for_replicas_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index 3c2850c..323bea9 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -221,15 +221,23 @@
 
 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
   auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
-  for (auto& entry : cache) {
-    if (entry.first == target) {
-      return entry.second;
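+  // A reshard toward fewer tiles materializes an all-gather; such results are
+  // looked up in (and, below, added to) the cache only when cache_all_gather
+  // is enabled.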
+  const bool is_to_replicate =
+      hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
+  if (!is_to_replicate || state_.partitioner->options().cache_all_gather) {
+    for (auto& entry : cache) {
+      if (entry.first == target) {
+        return entry.second;
+      }
     }
   }
-  cache.emplace_back(target, ReshardNoCache(target));
-  state_.reshard_cache->per_hlo_cache[cache.back().second.hlo()]
+  auto resharded = ReshardNoCache(target);
+  state_.reshard_cache->per_hlo_cache[resharded.hlo()]
       .reshard_cache.emplace_back(sharding(), *this);
-  return cache.back().second;
+  if (!is_to_replicate || state_.partitioner->options().cache_all_gather) {
+    cache.emplace_back(target, std::move(resharded));
+    return cache.back().second;
+  }
+  return resharded;
 }
 
 PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
@@ -282,133 +290,17 @@
     return ReshardWithAllToAll(target, *src_tgt_dims);
   }
 
-  // Partial replicated to tiled.
-  if (sharding().ReplicateOnLastTileDim() && !target.ReplicateOnLastTileDim() &&
-      !target.IsTileMaximal()) {
-    // Get the temp sharding target from partial replicate to target tile dims.
-    // target_compatible_sharding has the same tile_assignment dimensions
-    // as the target and can reshard to target by collective permute.
-    // target_compatible_sharding could have different device assignment as
-    // targe. sharding() can reshard to target_compatible_sharding by
-    // dynamic slice.
-    auto target_compatible_sharding = PartialReplicateToTileCompatibleSharding(
-        sharding(), target.tile_assignment().dimensions());
-    // Reshard to target_compatible_sharding by dynamic slice.
-    if (target_compatible_sharding.has_value()) {
-      std::vector<int64> expand_tile_dims;
-      std::vector<int64> tiling_dim_factors;
-      int64 rank = shape.rank();
-      tiling_dim_factors.reserve(rank);
-      auto temp_target_sharding = target_compatible_sharding.value();
-      for (int64 dim = 0; dim < rank; dim++) {
-        if (temp_target_sharding.tile_assignment().dim(dim) >
-            sharding().tile_assignment().dim(dim)) {
-          expand_tile_dims.push_back(dim);
-        }
-        tiling_dim_factors.emplace_back(
-            temp_target_sharding.tile_assignment().dim(dim) /
-            sharding().tile_assignment().dim(dim));
-      }
-
-      // Get per_group partitioner state.
-      std::vector<int64> group_dims(
-          sharding().tile_assignment().num_dimensions() - 1);
-      std::iota(group_dims.begin(), group_dims.end(), 0);
-      auto sharding_grouped = GroupShardingOnDims(sharding(), group_dims);
-      auto per_group_partitioner_state = CreatePerGroupPartitioningState(
-          state_, sharding_grouped.device_groups, state_.b);
-      // 2. Get the padded_hlo, do right halo exchange if needed.
-      auto padded_hlo = PadFromPartialReplicateShape(
-          hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
-          state_.collective_ops_creator, state_.next_channel_id,
-          state_.partition_id, state_.b);
-      if (padded_hlo.has_value()) {
-        // 3. Slice out the tile from replicate ones.
-        auto shard_shape =
-            MakePartitionedShape(base_shape_, temp_target_sharding);
-        // device assignment within each group is sorted in
-        // HloSharding::PartialTile, thus partiton_id within each group can be
-        // matched with the order in tile_assignment.
-        Array<int64> tiling_assignment(tiling_dim_factors);
-        tiling_assignment.FillIota(0);
-        auto slice =
-            state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
-                shard_shape, padded_hlo.value(),
-                MakePartitionOffsets(padded_hlo.value()->shape(),
-                                     HloSharding::Tile(tiling_assignment),
-                                     per_group_partitioner_state.partition_id,
-                                     per_group_partitioner_state.b),
-                shard_shape.dimensions()));
-        slice->set_sharding(temp_target_sharding);
-        auto result = PartitionedHlo(slice, base_shape_, state_);
-        // If temp_target_sharding's device assignment is different from target,
-        // use collective permute to reshard.
-        if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
-          return result.ReshardWithCollectivePermute(target);
-        }
-        // If device assignment in temp_target_sharding and target are the same,
-        // return result directly.
-        return result;
-      }
+  if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) {
+    auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target);
+    if (try_reshard.has_value()) {
+      return try_reshard.value();
     }
   }
 
-  // Tiled to partial replicate
-  if (!sharding().ReplicateOnLastTileDim() && !sharding().IsTileMaximal() &&
-      target.ReplicateOnLastTileDim()) {
-    // Get the comptible sharding to target with resharding by all reduce.
-    auto compatible_sharding = PartialReplicateToTileCompatibleSharding(
-        target, sharding().tile_assignment().dimensions());
-    if (compatible_sharding.has_value()) {
-      auto temp_sharding = compatible_sharding.value();
-      auto partitioned_hlo = *this;
-      // Use collective permute to adjust device assignment if needed.
-      if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) {
-        partitioned_hlo =
-            partitioned_hlo.ReshardWithCollectivePermute(temp_sharding);
-      }
-
-      // Get replicate dims and replicate factor of each dimensions.
-      int64 rank = hlo_->shape().rank();
-      std::vector<int64> replicate_dims;
-      std::vector<int64> replicate_factors;
-      for (int64 dim = 0; dim < rank; dim++) {
-        int64 replicate_factor = temp_sharding.tile_assignment().dim(dim) /
-                                 target.tile_assignment().dim(dim);
-        if (replicate_factor > 1) {
-          replicate_dims.emplace_back(dim);
-          replicate_factors.emplace_back(replicate_factor);
-        }
-      }
-
-      // Do left halo exchange if all-reduce directly will remove useful data
-      // from the source.
-      auto halo_exchange = TileToPartialReplicateHaloExchange(
-          partitioned_hlo.hlo_, base_shape_, temp_sharding, target,
-          replicate_dims, partitioned_hlo.state().collective_ops_creator,
-          partitioned_hlo.state().next_channel_id,
-          partitioned_hlo.state().partition_id, partitioned_hlo.state().b);
-      if (halo_exchange.has_value()) {
-        auto halo_exchange_hlo = halo_exchange.value();
-        // Grouped on replicate dimensions.
-        auto sharding_grouped = GroupShardingOnDims(
-            temp_sharding, replicate_dims, replicate_factors);
-        auto per_group_partitioner_state = CreatePerGroupPartitioningState(
-            partitioned_hlo.state(), sharding_grouped.device_groups,
-            partitioned_hlo.state().b);
-        auto base_shape = MakePartitionedShape(base_shape_, target);
-        // It's possible that halo_exchange_hlo == hlo.hlo().
-        // Record the sharding of hlo here, and reset it before return.
-        auto original_sharding = partitioned_hlo.sharding();
-        halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
-        auto partial_replicate_hlo = PartitionedHlo(
-            halo_exchange_hlo, base_shape, per_group_partitioner_state);
-        HloInstruction* result =
-            partial_replicate_hlo.ReplicatePartial(replicate_dims);
-        partitioned_hlo.hlo()->set_sharding(original_sharding);
-        result->set_sharding(target);
-        return PartitionedHlo(result, base_shape_, partitioned_hlo.state());
-      }
+  if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
+    auto try_reshard = ReshardToPartialReplicateWithAllGather(target);
+    if (try_reshard.has_value()) {
+      return try_reshard.value();
     }
   }
 
@@ -794,6 +686,14 @@
 }
 
 PartitionedHlo PartitionedHlo::Replicate() {
+  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
+  if (state_.partitioner->options().cache_all_gather) {
+    for (auto& entry : cache) {
+      if (entry.first.IsReplicated()) {
+        return entry.second;
+      }
+    }
+  }
   const HloSharding& sharding = hlo_->sharding();
   const Shape& shape = hlo_->shape();
   CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
@@ -801,7 +701,6 @@
   if (sharding.IsReplicated()) {
     return *this;
   }
-  auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
   for (auto& entry : cache) {
     if (entry.first.IsReplicated()) {
       return entry.second;
@@ -810,8 +709,11 @@
   auto update_cache = [&](PartitionedHlo resharded) {
     state_.reshard_cache->per_hlo_cache[resharded.hlo()]
         .reshard_cache.emplace_back(sharding, *this);
-    cache.emplace_back(HloSharding::Replicate(), std::move(resharded));
-    return cache.back().second;
+    if (state_.partitioner->options().cache_all_gather) {
+      cache.emplace_back(HloSharding::Replicate(), std::move(resharded));
+      return cache.back().second;
+    }
+    return resharded;
   };
   // 'Single Device' to 'Replicated'.
   if (sharding.IsTileMaximal()) {
@@ -872,6 +774,155 @@
   return result;
 }
 
+absl::optional<PartitionedHlo>
+PartitionedHlo::ReshardToPartialReplicateWithAllGather(
+    const HloSharding& target) {
+  if (!target.ReplicateOnLastTileDim()) {
+    return absl::nullopt;
+  }
+  // Tiled/partial replicate to partial replicate.
+  // Get the sharding compatible with the target, reachable by resharding
+  // with all-reduce.
+  auto compatible_sharding =
+      PartialReplicateReshardCompatibleSharding(target, sharding());
+  if (!compatible_sharding.has_value()) {
+    return absl::nullopt;
+  }
+
+  const auto& temp_sharding = compatible_sharding.value();
+  auto partitioned_hlo = *this;
+  // Use collective permute to adjust device assignment if needed.
+  if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) {
+    partitioned_hlo =
+        partitioned_hlo.ReshardWithCollectivePermute(temp_sharding);
+  }
+
+  // Get the replicate dims and the replicate factor of each dimension.
+  int64 rank = hlo_->shape().rank();
+  std::vector<int64> replicate_dims;
+  std::vector<int64> replicate_factors;
+  for (int64 dim = 0; dim < rank; dim++) {
+    int64 replicate_factor = temp_sharding.tile_assignment().dim(dim) /
+                             target.tile_assignment().dim(dim);
+    if (replicate_factor > 1) {
+      replicate_dims.emplace_back(dim);
+      replicate_factors.emplace_back(replicate_factor);
+    }
+  }
+
+  // Do a left halo exchange if an all-reduce performed directly would remove
+  // useful data from the source.
+  auto halo_exchange = TileToPartialReplicateHaloExchange(
+      partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims,
+      partitioned_hlo.state().collective_ops_creator,
+      partitioned_hlo.state().next_channel_id,
+      partitioned_hlo.state().partition_id, partitioned_hlo.state().b);
+  if (!halo_exchange.has_value()) {
+    return absl::nullopt;
+  }
+  auto halo_exchange_hlo = halo_exchange.value();
+  // Grouped on replicate dimensions.
+  auto sharding_grouped =
+      GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors);
+  auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+      partitioned_hlo.state(), sharding_grouped.device_groups,
+      partitioned_hlo.state().b);
+  auto base_shape = MakePartitionedShape(base_shape_, target);
+  // It's possible that halo_exchange_hlo == partitioned_hlo.hlo().
+  // Record the sharding of the hlo here, and reset it before returning.
+  auto original_sharding = partitioned_hlo.sharding();
+  halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
+  auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape,
+                                              per_group_partitioner_state);
+  HloInstruction* result =
+      partial_replicate_hlo.ReplicatePartial(replicate_dims);
+  partitioned_hlo.hlo()->set_sharding(original_sharding);
+  result->set_sharding(target);
+  return PartitionedHlo(result, base_shape_, partitioned_hlo.state());
+}
+
+absl::optional<PartitionedHlo>
+PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
+    const HloSharding& target) {
+  if (!sharding().ReplicateOnLastTileDim()) {
+    return absl::nullopt;
+  }
+
+  // Get the temp sharding target from partial replicate to the target tile
+  // dims. target_compatible_sharding has the same tile_assignment dimensions
+  // as the target and can be resharded to the target by collective permute.
+  // target_compatible_sharding could have a different device assignment than
+  // the target. sharding() can reshard to target_compatible_sharding by
+  // dynamic slice.
+  auto target_compatible_sharding =
+      PartialReplicateReshardCompatibleSharding(sharding(), target);
+  // Reshard to target_compatible_sharding by dynamic slice.
+  if (!target_compatible_sharding.has_value()) {
+    return absl::nullopt;
+  }
+  std::vector<int64> expand_tile_dims;
+  std::vector<int64> tiling_dim_factors;
+  int64 rank = hlo_->shape().rank();
+  tiling_dim_factors.reserve(target.tile_assignment().num_dimensions());
+  const auto& temp_target_sharding = target_compatible_sharding.value();
+  for (int64 dim = 0; dim < rank; dim++) {
+    if (temp_target_sharding.tile_assignment().dim(dim) >
+        sharding().tile_assignment().dim(dim)) {
+      expand_tile_dims.push_back(dim);
+    }
+    tiling_dim_factors.emplace_back(
+        temp_target_sharding.tile_assignment().dim(dim) /
+        sharding().tile_assignment().dim(dim));
+  }
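+  // For example, resharding from
+  // {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} to a compatible [2,2]
+  // tiling gives expand_tile_dims={1} and tiling_dim_factors={1,2}:
+  // dimension 1 is split two ways within each replication group.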
+
+  // Add another dimension to tiling_dim_factors if the target is partially
+  // replicated.
+  if (target.ReplicateOnLastTileDim()) {
+    tiling_dim_factors.emplace_back(
+        target.tile_assignment().dimensions().back());
+  }
+
+  // 1. Get the per_group partitioner state.
+  std::vector<int64> group_dims(sharding().tile_assignment().num_dimensions() -
+                                1);
+  std::iota(group_dims.begin(), group_dims.end(), 0);
+  auto sharding_grouped = GroupShardingOnDims(sharding(), group_dims);
+  auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+      state_, sharding_grouped.device_groups, state_.b);
+  // 2. Get the padded_hlo, do right halo exchange if needed.
+  auto padded_hlo = PadFromPartialReplicateShape(
+      hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
+      state_.collective_ops_creator, state_.next_channel_id,
+      state_.partition_id, state_.b);
+  if (!padded_hlo.has_value()) {
+    return absl::nullopt;
+  }
+  // 3. Slice out the tile from the replicated ones.
+  auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding);
+  // The device assignment within each group is sorted in
+  // HloSharding::PartialTile, thus the partition_id within each group can be
+  // matched with the order in tile_assignment.
+  Array<int64> tiling_assignment(tiling_dim_factors);
+  tiling_assignment.FillIota(0);
+  auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
+      shard_shape, padded_hlo.value(),
+      MakePartitionOffsets(padded_hlo.value()->shape(),
+                           target.ReplicateOnLastTileDim()
+                               ? HloSharding::PartialTile(tiling_assignment)
+                               : HloSharding::Tile(tiling_assignment),
+                           per_group_partitioner_state.partition_id,
+                           per_group_partitioner_state.b),
+      shard_shape.dimensions()));
+  slice->set_sharding(temp_target_sharding);
+  auto result = PartitionedHlo(slice, base_shape_, state_);
+  // If temp_target_sharding's device assignment is different from target,
+  // use collective permute to reshard.
+  if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
+    return result.ReshardWithCollectivePermute(target);
+  }
+  // If the device assignments in temp_target_sharding and the target are the
+  // same, return the result directly.
+  return result;
+}
+
 PartitionedHlo PartitionedHlo::Broadcast() const {
   const Shape& shape = hlo_->shape();
   const HloSharding& sharding = hlo_->sharding();
@@ -1048,6 +1099,25 @@
     const HloSharding& target) const {
   CHECK(CanReshardWithCollectivePermute(sharding(), target))
       << sharding().ToString() << " to " << target.ToString();
+  if (hlo()->opcode() == HloOpcode::kBroadcast) {
+    // If hlo() is a broadcast, check if data is already the same between
+    // source/destination pairs.
+    std::vector<int64> new_dims;
+    for (int64 i = 0; i < hlo()->shape().rank(); ++i) {
+      if (!absl::c_linear_search(hlo()->dimensions(), i)) {
+        new_dims.push_back(i);
+      }
+    }
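+    // A broadcast's data is constant along the dims it introduces, so if the
+    // two shardings agree once those dims are replicated, every device
+    // already holds the correct data and a plain copy suffices.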
+    if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(sharding(),
+                                                                 new_dims) ==
+        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(target,
+                                                                 new_dims)) {
+      auto copy = state_.b->AddInstruction(
+          HloInstruction::CreateUnary(hlo()->shape(), HloOpcode::kCopy, hlo()));
+      copy->set_sharding(target);
+      return PartitionedHlo(copy, base_shape_, state_);
+    }
+  }
   std::vector<std::pair<int64, int64>> src_dst_pairs;
   sharding().tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 src_device) {
@@ -1219,7 +1289,7 @@
 // gather/scatter slice size 1.
 bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
     const PartitionedHlo& operand, absl::Span<const int64> index_map,
-    absl::Span<const int64> slice_size, int64 num_partitions) {
+    absl::Span<const int64> slice_size) {
   if (operand.sharding().IsTileMaximal()) {
     return false;
   }
@@ -1230,7 +1300,7 @@
           operand.sharding().tile_assignment().dim(dim);
     }
   }
-  return trivial_slice_dims_partitions == num_partitions;
+  return trivial_slice_dims_partitions == operand.sharding().NumTiles();
 }
 
 // Returns the min and max for the indices (replicated) in a scatter/gather
@@ -1381,10 +1451,23 @@
               update_dim_to_index_dim);
       CHECK(new_updates_sharding.has_value());
       updates = updates.Reshard(*new_updates_sharding);
+      // Update collective_ops_creator and partition_id for partial replicate.
+      auto collective_ops_creator = collective_ops_creator_;
+      auto partition_id = partition_id_;
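+      // Partially replicated indices mean the all-reduce below must combine
+      // only within each replication group, so switch to the per-group
+      // collectives creator and partition id.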
+      if (indices.sharding().ReplicateOnLastTileDim()) {
+        auto sharding_grouped = GroupShardingOnDims(
+            indices.sharding(),
+            {indices.sharding().tile_assignment().num_dimensions() - 1});
+        auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+            indices.state(), sharding_grouped.device_groups, &b_);
+        collective_ops_creator =
+            per_group_partitioner_state.collective_ops_creator;
+        partition_id = per_group_partitioner_state.partition_id;
+      }
       // To avoid accumulating the initial operand multiple times during
       // all-reduce, we use identity operands for all non-zero partitions.
       auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert(
-          ShapeUtil::MakeScalarShape(PRED), partition_id_));
+          ShapeUtil::MakeScalarShape(PRED), partition_id));
       not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast(
           ShapeUtil::ChangeElementType(identity->shape(), PRED),
           not_partition_zero, {}));
@@ -1395,7 +1478,7 @@
       auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands(
           scatter->shape(), {select_operand, indices.hlo(), updates.hlo()}));
       auto all_reduce =
-          collective_ops_creator_.create_cross_partition_all_reduce(
+          collective_ops_creator.create_cross_partition_all_reduce(
               &b_, pscatter, scatter->to_apply(), {}, NewChannel());
       all_reduce->set_sharding(HloSharding::Replicate());
       SetPartitionedHlo(hlo, [&]() {
@@ -1425,8 +1508,7 @@
       return Status::OK();
     }
     if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
-            operand, scatter_dims_to_operand_dims, slice_size,
-            num_partitions_) &&
+            operand, scatter_dims_to_operand_dims, slice_size) &&
         ShapeSizeInBytes(updates.base_shape()) <
             ShapeSizeInBytes(scatter->shape())) {
       // Operand is sharded on trivial slice dims (update slice size 1). We can
@@ -2301,8 +2383,7 @@
       return Status::OK();
     }
     if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
-            operand, start_index_map, gather->gather_slice_sizes(),
-            num_partitions_) &&
+            operand, start_index_map, gather->gather_slice_sizes()) &&
         ShapeSizeInBytes(gather->shape()) <
             ShapeSizeInBytes(gather->operand(0)->shape())) {
       indices = indices.Reshard(HloSharding::Replicate());
@@ -2364,7 +2445,17 @@
           pgather->shape(), HloOpcode::kSelect, broadcast_filter,
           CreateZero(pgather->shape(), &b_), pgather));
       // Combine from different partitions.
-      auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
+      auto collective_ops_creator = collective_ops_creator_;
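+      // As with scatter above, a partially replicated operand requires the
+      // all-reduce to combine within each replication group only.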
+      if (operand.sharding().ReplicateOnLastTileDim()) {
+        auto sharding_grouped = GroupShardingOnDims(
+            operand.sharding(),
+            {operand.sharding().tile_assignment().num_dimensions() - 1});
+        auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+            operand.state(), sharding_grouped.device_groups, &b_);
+        collective_ops_creator =
+            per_group_partitioner_state.collective_ops_creator;
+      }
+      auto ar = collective_ops_creator.create_cross_partition_all_reduce(
           &b_, filtered,
           MakeBinaryAdd(filtered->shape().element_type(), module_), {},
           NewChannel());
@@ -3370,7 +3461,7 @@
     HloPassPipeline pass("spmd-cleanup");
     pass.AddPass<TupleSimplifier>();
     pass.AddPass<HloDCE>();
-    pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+    pass.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
     pass.AddPass<FlattenCallGraph>();
     TF_RETURN_IF_ERROR(pass.Run(module).status());
   }
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index a612c16..6447d08 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -47,6 +47,12 @@
 
   // Whether the entry computations' signature could change after partitioning.
   bool allow_module_signature_change = false;
+
+  // Whether to cache the all-gather result to avoid repeatedly replicating a
+  // tiled tensor. If set to false, the result tends to be more
+  // memory-efficient, and the compiler can use the ScheduleAwareAllGatherCSE
+  // pass to CSE all-gathers that are relatively close to each other.
+  bool cache_all_gather = true;
 };
 
 // Class to wrap the computation builder to capture information during SPMD
@@ -180,6 +186,8 @@
       int64 channel_id, absl::Span<const int64> selected_dims,
       const SPMDCollectiveOpsCreator& collectives_creator);
 
+  const SpmdPartitionerOptions& options() { return options_; }
+
  protected:
   virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
       HloComputation* computation, int64 num_partitions, int64 num_replicas,
@@ -305,6 +313,14 @@
   // Helper function to reshard the tensor using CollectivePermute.
   PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;
 
+  // Helper function to reshard to partial replicate using AllGather.
+  absl::optional<PartitionedHlo> ReshardToPartialReplicateWithAllGather(
+      const HloSharding& target);
+
+  // Helper function to reshard from partial replicate using DynamicSlice.
+  absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
+      const HloSharding& target);
+
   // SPMD instruction.
   HloInstruction* hlo_;
 
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index e2826b2..8caaba8 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -138,8 +138,7 @@
               op::AllReduce(op::Select(
                   op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
                   op::Constant(), op::Broadcast())),
-              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(),
-                                           op::Constant())),
+              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
               op::Constant())),
           op::Shape("s32[1,3]")));
 }
@@ -161,8 +160,7 @@
       op::Copy(op::AllReduce(AllOf(
           op::DynamicUpdateSlice(
               op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")),
-              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(),
-                                           op::Constant())),
+              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
               op::Constant()),
           op::Shape("s32[2,3]")))));
 }
@@ -184,8 +182,7 @@
       op::Copy(op::Copy(op::AllReduce(AllOf(
           op::DynamicUpdateSlice(
               op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")),
-              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(),
-                                           op::Constant())),
+              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
               op::Constant()),
           op::Shape("s32[2,3]"))))));
 }
@@ -279,8 +276,8 @@
   HloInstruction* root = module->entry_computation()->root_instruction();
   ASSERT_THAT(root, op::Tuple());
 
-  auto offset = op::Reshape(
-      op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant()));
+  auto offset =
+      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
 
   EXPECT_THAT(root->operand(0),
               op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset,
@@ -305,13 +302,13 @@
                           PartitionComputation(hlo_string, /*num_devices=*/2));
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(
-      root, op::Copy(op::AllReduce(op::DynamicUpdateSlice(
-                op::Broadcast(),
-                op::GetTupleElement(
-                    AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))),
-                op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(),
-                                             op::Constant())),
-                op::Constant()))));
+      root,
+      op::Copy(op::AllReduce(op::DynamicUpdateSlice(
+          op::Broadcast(),
+          op::GetTupleElement(
+              AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))),
+          op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
+          op::Constant()))));
 }
 
 TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) {
@@ -3923,6 +3920,26 @@
                           op::Shape("f32[3,5]")));
 }
 
+TEST_F(SpmdPartitioningTest, PassthroughGather_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0),
+    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+  %indices = s32[3] parameter(1), sharding={replicated}
+  ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
+    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
+    slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
+                          op::Shape("f32[3,5]")));
+}
+
 TEST_F(SpmdPartitioningTest, IndexPassthroughGather) {
   const char* const hlo_string = R"(
 HloModule module
@@ -3942,6 +3959,27 @@
                           op::Shape("f32[8,2,2]")));
 }
 
+TEST_F(SpmdPartitioningTest, IndexPassthroughGather_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = f32[2,9,8] parameter(0), sharding={replicated}
+  %indices = s32[4,2,4] parameter(1),
+    sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
+    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
+    slice_sizes={1,1,8},
+    sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
+                          op::Shape("f32[8,2,2]")));
+}
+
 TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) {
   const char* const hlo_string = R"(
 HloModule module
@@ -3956,8 +3994,39 @@
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           PartitionComputation(hlo_string, /*num_devices=*/2));
   VLOG(1) << module->ToString();
-  auto offset = op::Reshape(
-      op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant()));
+  auto offset =
+      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
+  auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"));
+  auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())),
+                   op::Shape("s32[2,3]"));
+  auto clamp = op::Clamp(min, op::Parameter(1), max);
+  auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min));
+  auto mask =
+      op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max));
+  auto masked =
+      op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather);
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]")));
+}
+
+TEST_F(SpmdPartitioningTest,
+       GatherPartitionedOnTrivialSliceDims_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = f32[17,9] parameter(0),
+    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+  %indices = s32[2,3] parameter(1), sharding={replicated}
+  ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
+    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
+    slice_sizes={1,9}, sharding={replicated}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+  auto offset =
+      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
   auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"));
   auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())),
                    op::Shape("s32[2,3]"));
@@ -4001,6 +4070,39 @@
                           op::Shape("f32[2,5]")));
 }
 
+TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %input = f32[2,9] parameter(0),
+    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+  %indices = s32[3] parameter(1), sharding={replicated}
+  %updates = f32[3,9] parameter(2),
+    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1,
+      sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
+                                      op::Parameter(2)),
+                          op::Shape("f32[2,5]")));
+}
+
 TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) {
   const char* const hlo_string = R"(
 HloModule module
@@ -4035,6 +4137,42 @@
             op::Shape("f32[2,9,8]")));
 }
 
+TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %input = f32[2,9,8] parameter(0), sharding={replicated}
+  %indices = s32[4,2,4] parameter(1),
+    sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  %updates = f32[4,4,8] parameter(2),
+    sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={2},
+      inserted_window_dims={0,1},
+      scatter_dims_to_operand_dims={0,1},
+      index_vector_dim=1, sharding={replicated}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      AllOf(op::AllReduce(op::Scatter(
+                op::Select(op::Broadcast(op::Convert(op::Reshape())),
+                           op::Broadcast(op::Constant()), op::Parameter(0)),
+                op::Parameter(1), op::Parameter(2))),
+            op::Shape("f32[2,9,8]")));
+}
+
 TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) {
   const char* const hlo_string = R"(
 HloModule module
@@ -4093,8 +4231,45 @@
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           PartitionComputation(hlo_string, /*num_devices=*/2));
   VLOG(1) << module->ToString();
-  auto offset = op::Reshape(
-      op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant()));
+  auto offset =
+      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
+  auto indices = op::Subtract(
+      op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)),
+                    op::Shape("f32[9,9]")));
+}
+
+TEST_F(SpmdPartitioningTest,
+       ScatterPartitionedOnTrivialSliceDims_PartialReplicate) {
+  const char* const hlo_string = R"(
+HloModule module
+
+add (lhs: f32[], rhs: f32[]) -> f32[] {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT sum = f32[] add(lhs, rhs)
+}
+
+ENTRY entry {
+  %input = f32[17,9] parameter(0),
+    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+  %indices = s32[2,3] parameter(1), sharding={replicated}
+  %updates = f32[2,3,9] parameter(2), sharding={replicated}
+  ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
+      to_apply=add,
+      update_window_dims={2},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=2,
+      sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+  auto offset =
+      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
   auto indices = op::Subtract(
       op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
   HloInstruction* root = module->entry_computation()->root_instruction();
@@ -4834,6 +5009,266 @@
   EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]"), op::Add(add_lhs, add_rhs)));
 }
 
+TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshard) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0)
+  %copy = f32[8,8] copy(%param0),
+    sharding={devices=[2,2]0,1,2,3}
+  ROOT %copy0 = f32[8,8] copy(%copy),
+    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+  auto tiled = AllOf(op::Shape("f32[4,4]"),
+                     op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
+                                               op::Reshape())));
+  auto partially_replicated = AllOf(
+      op::Shape("f32[4,8]"), op::Copy(op::AllReduce(op::DynamicUpdateSlice(
+                                 op::Broadcast(_), tiled, _, _))));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, partially_replicated);
+}
+
+TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0)
+  %copy = f32[8,8] copy(%param0),
+    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%copy),
+    sharding={devices=[2,2]0,1,2,3}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+  auto partially_replicated =
+      AllOf(op::Shape("f32[4,8]"),
+            op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
+                                      op::Constant())));
+  auto tiled = AllOf(op::Shape("f32[4,4]"),
+                     op::Copy(op::DynamicSlice(partially_replicated,
+                                               op::Constant(), op::Reshape())));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, tiled);
+}
+
+TEST_F(SpmdPartitioningTest,
+       PartialReplicateToPartialReplicateReshard_AllReduce) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0)
+  %copy = f32[8,8] copy(%param0),
+    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%copy),
+    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+
+  VLOG(1) << module->ToString();
+  auto partially_replicated_init =
+      AllOf(op::Shape("f32[4,4]"),
+            op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
+                                      op::Reshape())));
+  auto partially_replicated =
+      AllOf(op::Shape("f32[4,8]"),
+            op::Copy(op::AllReduce(op::DynamicUpdateSlice(
+                op::Broadcast(_), partially_replicated_init, _, _))));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, partially_replicated);
+}
+
+TEST_F(SpmdPartitioningTest,
+       PartialReplicateToPartialReplicateReshard_DynamicSlice) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0)
+  %copy = f32[8,8] copy(%param0),
+    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%copy),
+    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+  auto partially_replicated =
+      AllOf(op::Shape("f32[4,8]"),
+            op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
+                                      op::Constant())));
+  auto tiled = AllOf(op::Shape("f32[4,4]"),
+                     op::Copy(op::DynamicSlice(partially_replicated,
+                                               op::Constant(), op::Reshape())));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, tiled);
+}
+
+TEST_F(SpmdPartitioningTest,
+       PartialReplicateToPartialReplicateReshard_DynamicSlice2) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0)
+  %copy = f32[8,8] copy(%param0),
+    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%copy),
+    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+  auto partially_replicated =
+      AllOf(op::Shape("f32[8,8]"),
+            op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
+                                      op::Constant())));
+  auto tiled = AllOf(op::Shape("f32[4,4]"),
+                     op::Copy(op::DynamicSlice(partially_replicated,
+                                               op::Reshape(), op::Reshape())));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, tiled);
+}
+
+TEST_F(SpmdPartitioningTest,
+       PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0)
+  %copy = f32[8,8] copy(%param0),
+    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%copy),
+    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+
+  VLOG(1) << module->ToString();
+  auto partially_replicated_init =
+      AllOf(op::Shape("f32[4,4]"),
+            op::CollectivePermute(op::Copy(op::DynamicSlice(
+                op::Parameter(0), op::Reshape(), op::Reshape()))));
+  auto partially_replicated =
+      AllOf(op::Shape("f32[8,4]"),
+            op::Copy(op::AllReduce(op::DynamicUpdateSlice(
+                op::Broadcast(_), partially_replicated_init, _, _))));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, partially_replicated);
+}
+
+TEST_F(SpmdPartitioningTest,
+       PartialReplicateToPartialReplicateReshardCollectivePermute1) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0)
+  %copy = f32[8,8] copy(%param0),
+    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%copy),
+    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+  auto partially_replicated =
+      AllOf(op::Shape("f32[8,4]"),
+            op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
+                                      op::Reshape())));
+  auto tiled =
+      AllOf(op::Shape("f32[4,4]"),
+            op::Copy(op::CollectivePermute(op::DynamicSlice(
+                partially_replicated, op::Reshape(), op::Constant()))));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, tiled);
+}
+
+TEST_F(SpmdPartitioningTest,
+       PartialReplicateToPartialReplicateReshardHaloExchange) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[6,3] parameter(0)
+  %copy = f32[6,3] copy(%param0),
+    sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %copy0 = f32[6,3] copy(%copy),
+    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+
+  VLOG(1) << module->ToString();
+  auto partially_replicated_init =
+      AllOf(op::Shape("f32[2,3]"),
+            op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
+                                      op::Reshape(), op::Constant())));
+  auto slice =
+      AllOf(op::Shape("f32[2,3]"),
+            op::DynamicSlice(op::Concatenate(op::CollectivePermute(op::Slice(
+                                                 partially_replicated_init)),
+                                             partially_replicated_init),
+                             _, _));
+  auto partially_replicated =
+      AllOf(op::Shape("f32[3,3]"),
+            op::Copy(op::Slice(op::AllReduce(
+                op::DynamicUpdateSlice(op::Broadcast(_), slice, _, _)))));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, partially_replicated);
+}
+
+TEST_F(SpmdPartitioningTest,
+       PartialReplicateToPartialReplicateReshardHaloExchange1) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[6,3] parameter(0)
+  %copy = f32[6,3] copy(%param0),
+    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %copy0 = f32[6,3] copy(%copy),
+    sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+
+  VLOG(1) << module->ToString();
+  auto partially_replicated_init =
+      AllOf(op::Shape("f32[3,3]"),
+            op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
+                                      op::Constant())));
+  auto slice = AllOf(
+      op::Shape("f32[4,3]"),
+      op::DynamicSlice(op::Pad(op::Concatenate(partially_replicated_init,
+                                               op::CollectivePermute(op::Slice(
+                                                   partially_replicated_init))),
+                               op::Constant()),
+                       _, _));
+  auto partially_replicated =
+      AllOf(op::Shape("f32[2,3]"), op::Copy(op::DynamicSlice(slice, _, _)));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, partially_replicated);
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla
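
For reference, a worked reading of the TileToPartialReplicateReshard test above, using the f32[8,8] operand and four devices it declares. The source sharding {devices=[2,2]0,1,2,3} assigns each device one f32[4,4] tile; the target {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} keeps the row split but replicates each row block across the pair that previously split it by column:

device : source tile [rows, cols] -> target shard [rows, cols]
     0 : [0:4, 0:4]               -> [0:4, 0:8]
     1 : [0:4, 4:8]               -> [0:4, 0:8]
     2 : [4:8, 0:4]               -> [4:8, 0:8]
     3 : [4:8, 4:8]               -> [4:8, 0:8]

Each device pastes its tile into a zeroed f32[4,8] buffer at its old column offset and all-reduces within its pair ({0,1} and {2,3}), which is exactly the op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), tiled, _, _)) pattern the test expects: an all-gather along the column dimension expressed as DynamicUpdateSlice plus AllReduce.
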
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
index f20a26e..0edbd4f 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
@@ -29,6 +29,7 @@
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -202,13 +203,17 @@
     absl::Span<const int64> dims) {
   CHECK(!shape.IsTuple());
 
-  Array2D<int32> offset_array(
-      {sharding.tile_assignment().num_elements(), shape.rank()});
-  offset_array.Each([&](int64 i, int64 j, int32* value) {
-    *value = sharding.TileOffsetForDevice(shape, i)[j];
-  });
-  auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR2FromArray2D(offset_array)));
+  std::vector<std::vector<int32>> offset_arrays(shape.rank());
+  for (int64 i = 0; i < shape.rank(); ++i) {
+    offset_arrays[i].resize(sharding.tile_assignment().num_elements());
+  }
+  auto shard_shape = MakePartitionedShape(shape, sharding);
+  sharding.tile_assignment().Each(
+      [&](absl::Span<const int64> indices, int64 device) {
+        for (int64 i = 0; i < shape.rank(); ++i) {
+          offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i);
+        }
+      });
   std::vector<HloInstruction*> offsets;
   for (int64 i = 0; i < shape.rank(); ++i) {
     if (sharding.tile_assignment().dim(i) == 1 ||
@@ -216,11 +221,10 @@
       offsets.push_back(b->AddInstruction(
           HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
     } else {
+      auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR1<int32>(offset_arrays[i])));
       auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
-          ShapeUtil::MakeShape(S32, {1, 1}), offset_table,
-          {partition_id, b->AddInstruction(HloInstruction::CreateConstant(
-                             LiteralUtil::CreateR0<uint32>(i)))},
-          {1, 1}));
+          ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1}));
       offsets.push_back(b->AddInstruction(
           HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
     }
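
The two hunks above replace the single rank-2 offset table (indexed by {partition_id, dim} and sliced with sizes {1,1}) with one rank-1 table per dimension, indexed by partition_id alone; this is what lets the tests earlier in this diff drop the second op::Constant() operand from op::DynamicSlice(op::Constant(), op::PartitionId()). Below is a standalone sketch of the per-dimension tables for an f32[8,8] shape tiled {devices=[2,1]0,1}; the values are illustrative and this is not the XLA API:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> shape = {8, 8};
  const std::vector<int64_t> tiles = {2, 1};  // partitions per dimension
  // One offset table per dimension, indexed by device id: entry [d][device]
  // is where that device's shard starts along dimension d.
  std::vector<std::vector<int32_t>> offset_arrays(shape.size());
  for (std::size_t d = 0; d < shape.size(); ++d) {
    const int64_t shard_size = shape[d] / tiles[d];
    for (int64_t device = 0; device < tiles[0] * tiles[1]; ++device) {
      // Row-major device -> tile index along dimension d.
      const int64_t index = (d == 0) ? device / tiles[1] : device % tiles[1];
      offset_arrays[d].push_back(static_cast<int32_t>(index * shard_size));
    }
  }
  // offset_arrays[0] == {0, 4} and offset_arrays[1] == {0, 0}: each partition
  // dynamic-slices one scalar from the dim-0 table, while the all-zero dim-1
  // table is replaced by a constant zero (the tile dim there is 1).
  for (const auto& table : offset_arrays) {
    for (int32_t offset : table) std::cout << offset << ' ';
    std::cout << '\n';
  }
  return 0;
}
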
@@ -292,17 +296,29 @@
   return PadToShape(hlo, padded_base_shape, b);
 }
 
-// TODO(wangtao): generize this function when target is partial replicate.
-absl::optional<HloSharding> PartialReplicateToTileCompatibleSharding(
-    const HloSharding& partial_sharding,
-    const std::vector<int64>& target_tile_dims) {
+absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
+    const HloSharding& partial_sharding, const HloSharding& target_sharding) {
   if (!partial_sharding.ReplicateOnLastTileDim()) {
     return absl::nullopt;
   }
   int64 rank = partial_sharding.tile_assignment().num_dimensions() - 1;
-  if (target_tile_dims.size() < rank) {
+  int64 target_rank = target_sharding.tile_assignment().num_dimensions() -
+                      (target_sharding.ReplicateOnLastTileDim() ? 1 : 0);
+  if (target_rank != rank) {
     return absl::nullopt;
   }
+
+  absl::flat_hash_map<int64, int64> device_to_replication_group;
+  partial_sharding.tile_assignment().Each(
+      [&](absl::Span<const int64> indices, int64 device) {
+        int64 gid = 0;
+        for (int64 i = 0; i < rank; ++i) {
+          gid *= partial_sharding.tile_assignment().dim(i);
+          gid += indices[i];
+        }
+        device_to_replication_group[device] = gid;
+      });
+
   // A dimension is expanded when target_tile_size > partial_tile_size and
   // target_tile_size % partial_tile_size == 0.
   // expand_tile_dims_positions is the index of the expand_dim.
@@ -312,7 +328,7 @@
   int num_expand_dims = 0;
   for (int64 dim = 0; dim < rank; dim++) {
     int64 partial_tile_size = partial_sharding.tile_assignment().dim(dim);
-    int64 target_tile_size = target_tile_dims[dim];
+    int64 target_tile_size = target_sharding.tile_assignment().dim(dim);
     if (target_tile_size % partial_tile_size != 0 ||
         target_tile_size < partial_tile_size) {
       return absl::nullopt;
@@ -325,14 +341,26 @@
   }
 
   // Reshape the partial replicate tile_dimensions.
+  int64 num_target_replication = 1;
+  if (target_sharding.ReplicateOnLastTileDim()) {
+    num_target_replication =
+        target_sharding.tile_assignment().dimensions().back();
+  }
   auto reshape_dimensions = partial_sharding.tile_assignment().dimensions();
   int64 num_replication = reshape_dimensions.back();
-  if (num_replication != Product(expand_tile_sizes)) {
+  if (num_replication / num_target_replication != Product(expand_tile_sizes) ||
+      num_replication % num_target_replication != 0) {
     return absl::nullopt;
   }
+
   reshape_dimensions.pop_back();
   reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
                             expand_tile_sizes.end());
+
+  if (target_sharding.ReplicateOnLastTileDim()) {
+    reshape_dimensions.push_back(num_target_replication);
+  }
+
   auto reshape_tile_assignment = partial_sharding.tile_assignment();
   reshape_tile_assignment.Reshape(reshape_dimensions);
 
@@ -346,13 +374,31 @@
     }
   }
   auto transpose_sharding = hlo_sharding_util::TransposeSharding(
-      HloSharding::Tile(reshape_tile_assignment), perm);
+      target_sharding.ReplicateOnLastTileDim()
+          ? HloSharding::PartialTile(reshape_tile_assignment)
+          : HloSharding::Tile(reshape_tile_assignment),
+      perm);
 
   // Reshape to target shape
   auto transpose_tile_assignment = transpose_sharding.tile_assignment();
-  transpose_tile_assignment.Reshape(target_tile_dims);
+  transpose_tile_assignment.Reshape(
+      target_sharding.tile_assignment().dimensions());
 
-  return HloSharding::Tile(transpose_tile_assignment);
+  bool groups_matching = true;
+  target_sharding.tile_assignment().Each(
+      [&](absl::Span<const int64> indices, int64 device) {
+        if (device_to_replication_group[device] !=
+            device_to_replication_group[transpose_tile_assignment(indices)]) {
+          groups_matching = false;
+        }
+      });
+
+  if (groups_matching) {
+    return target_sharding;
+  }
+  return target_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(transpose_tile_assignment)
+             : HloSharding::Tile(transpose_tile_assignment);
 }
 
 absl::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
index 2d3bf3a..4fc193d 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
@@ -356,19 +356,19 @@
     const SPMDCollectiveOpsCreator& collective_ops_creator,
     int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);
 
-// Get the compatible sharding from a partial replicate sharding to a given
-// target tile dimensions.
+// Get the compatible sharding from a partial replicate sharding to a desired
+// target tiled sharding.
 // Compatible means replicate sharding can transform to the target tile
 // dimensions by dynamic slice.
 // For example, if partial_sharding is
 // {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
-// Target tile dims is {2, 2}, the returned compatible sharding will be
-// sharding={devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}.
+// Target sharding is {devices=[2,2]0,1,2,3}, the returned compatible sharding
+// will be sharding={devices=[2,2]0,2,1,3}.
 // If partial_sharding is not partially replicated or can't be resharded to
 // the target sharding by dynamic slice, return absl::nullopt.
-absl::optional<HloSharding> PartialReplicateToTileCompatibleSharding(
-    const HloSharding& partial_sharding,
-    const std::vector<int64>& target_tile_dims);
+// If target_sharding is already compatible, returns it.
+absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
+    const HloSharding& partial_sharding, const HloSharding& target_sharding);
 
 // Do left halo exchange if all-reduce directly from tile sharding to partial
 // replicate sharding will remove useful data from the source.
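
To make the resharding example in the header comment above concrete (using an f32[8,8] operand purely for illustration): the partial sharding {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} replicates column block 0 on devices {0,1} and column block 1 on devices {2,3}. A compatible tiled sharding must let every device keep data it already holds, so each replica pair splits its column block by row:

column block 0 (on {0,1})    column block 1 (on {2,3})
rows 0:4 -> device 0         rows 0:4 -> device 2
rows 4:8 -> device 1         rows 4:8 -> device 3

That is the tile assignment [[0,2],[1,3]], i.e. {devices=[2,2]0,2,1,3}, and the reshard reduces to a local DynamicSlice with no cross-device communication. The groups_matching check added in the .cc hunk returns target_sharding unchanged whenever the requested assignment already respects these replication groups.
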
diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc
index d54eb9e..4015c69 100644
--- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc
+++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc
@@ -89,16 +89,23 @@
     // The last block might be smaller than the block size,
     // so we will need to pad it
     if (n % block_size != 0) {
-      // Pad with zeros
+      // Pad with identity matrix.
       auto last_blocks =
           SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
       PaddingConfig config = MakeNoPaddingConfig(ndims);
       int64 padding = block_size - n % block_size;
-      config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding);
       config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
       last_blocks =
           Pad(last_blocks, Zero(builder, shape.element_type()), config);
 
+      auto eye =
+          IdentityMatrix(builder, shape.element_type(), padding, padding);
+      config = MakeNoPaddingConfig(ndims);
+      config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n %
+                                                                 block_size);
+      eye = Pad(eye, Zero(builder, shape.element_type()), config);
+      last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);
+
       // Add a singleton dimension
       // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
       TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
@@ -121,134 +128,6 @@
   });
 }
 
-XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a,
-                           bool conjugate_a,
-                           PrecisionConfig::Precision precision) {
-  XlaBuilder* builder = diag_blocks.builder();
-  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    // Input is a batch of square lower triangular square matrices. Its shape is
-    // (..., size, size). We resize this to (num_blocks, size, size).
-    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
-    int64 block_size = ShapeUtil::GetDimension(shape, -1);
-    int64 num_blocks = ShapeUtil::ElementsIn(shape) /
-                       tensorflow::MathUtil::IPow(block_size, 2);
-    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});
-
-    // The input must be triangular because we rely on that when doing
-    // multiplications later on
-    diag_blocks = Triangle(diag_blocks, /*lower=*/lower);
-
-    // Rescale blocks to be unit triangular, but avoid dividing by
-    // zero (which can happen if the last block was padded) otherwise it will
-    // introduce nans which will propagate
-    auto diags = GetMatrixDiagonal(diag_blocks);
-    auto ones = FullLike(diags, 1);
-    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
-    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
-
-    // We can now use the fact that for an upper triangular matrix
-    // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have
-    // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks
-    // have been rescaled to be unit triangular, so L22 = L22' = 1.
-
-    // Initialize the output matrix with -1s on the diagonal. We use -1 instead
-    // of 1 because we cannot do matrix-vector multiplies with variable shapes
-    // inside of a loop, or do irregularly shaped in-place updates. Hence,
-    // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
-    // entire row i.e. we calculate
-    // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
-    // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
-    auto identity =
-        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
-    auto neg_identity = -identity;
-
-    // The first or last  diagonal element should be set to 1 instead of -1
-    // though, since we never update it
-    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
-    auto start_index = ConstantR0<int>(builder, (lower) ? 0 : block_size - 1);
-    auto output_block =
-        DynamicUpdateSlice(neg_identity, pos_one,
-                           /*start_indices=*/{start_index, start_index});
-
-    // Broadcast diag([1, -1, -1, ...]) to every block
-    XlaOp output = Broadcast(output_block,
-                             /*broadcast_sizes=*/{num_blocks});
-
-    // Now we construct a loop that performs matrix-vector multiplications
-    // inverting the blocks one row at a time
-    std::vector<Shape> tuple_shapes = {
-        // The loop iteration counter is a scalar, incremented each iteration.
-        ShapeUtil::MakeShape(S32, {}),
-        // The output has the shape of A, with one row updated each iteration.
-        ShapeUtil::MakeShape(shape.element_type(),
-                             {num_blocks, block_size, block_size}),
-        // The input is a loop invariant.
-        ShapeUtil::MakeShape(shape.element_type(),
-                             {num_blocks, block_size, block_size})};
-    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
-
-    auto init_i = One(builder, S32);
-    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});
-
-    // Construct the loop condition function.
-    std::unique_ptr<XlaBuilder> condb =
-        builder->CreateSubBuilder("InvertDiagCond");
-    {
-      auto i = GetTupleElement(
-          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
-      Lt(i, ConstantR0<int32>(condb.get(), block_size));
-    }
-    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
-
-    // Construct the loop body function.
-    std::unique_ptr<XlaBuilder> bodyb =
-        builder->CreateSubBuilder("InvertDiagBody");
-    {
-      auto input_tuple =
-          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");
-
-      auto i = GetTupleElement(input_tuple, 0);
-      auto body_out = GetTupleElement(input_tuple, 1);
-      auto body_input = GetTupleElement(input_tuple, 2);
-
-      auto zero = ConstantR0<int32>(bodyb.get(), 0);
-      auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i;
-      auto input_row =
-          DynamicSlice(body_input, {zero, j, zero},
-                       /*slice_sizes=*/{num_blocks, 1, block_size});
-
-      // We want -L21 L11^{-1}
-      DotDimensionNumbers dnums;
-      dnums.add_lhs_batch_dimensions(0);
-      dnums.add_rhs_batch_dimensions(0);
-      dnums.add_lhs_contracting_dimensions(2);
-      dnums.add_rhs_contracting_dimensions(1);
-      PrecisionConfig precision_proto;
-      precision_proto.add_operand_precision(precision);
-      precision_proto.add_operand_precision(precision);
-      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
-
-      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});
-
-      auto next_i = i + ScalarLike(i, 1);
-      Tuple(bodyb.get(), {next_i, body_out, body_input});
-    }
-    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
-
-    // Construct the While loop and return the result,
-    // return while_loop(cond_fun, body_fun, init)[1]
-    auto invert_while = While(cond, body, init);
-    auto inv_diag_blocks = GetTupleElement(invert_while, 1);
-
-    // Undo the scaling
-    inv_diag_blocks = Div(inv_diag_blocks, diags,
-                          /*broadcast_dimensions=*/{0, 1});
-
-    // Reshape back to original batch major dimensions
-    return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
-  });
-}
-
 XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
                                       bool left_side, bool lower,
                                       bool transpose_a, bool conjugate_a,
@@ -357,10 +236,140 @@
   });
 }
 
-XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
-                           bool transpose_a, bool conjugate_a,
-                           bool unit_diagonal, int64 block_size,
-                           PrecisionConfig::Precision precision) {
+}  // namespace
+
+XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
+    XlaOp diag_blocks, bool lower_triangular,
+    PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = diag_blocks.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    // Input is a batch of square lower-triangular matrices. Its shape is
+    // (..., size, size). We resize this to (num_blocks, size, size).
+    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
+    int64 block_size = ShapeUtil::GetDimension(shape, -1);
+    int64 num_blocks = ShapeUtil::ElementsIn(shape) /
+                       tensorflow::MathUtil::IPow(block_size, 2);
+    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});
+
+    // The input must be triangular because we rely on that when doing
+    // multiplications later on
+    diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);
+
+    // Rescale blocks to be unit triangular, but avoid dividing by
+    // zero (which can happen if the last block was padded); otherwise it
+    // would introduce NaNs, which would propagate.
+    auto diags = GetMatrixDiagonal(diag_blocks);
+    auto ones = FullLike(diags, 1);
+    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
+    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
+
+    // We can now use the fact that for a lower triangular matrix
+    // [[L11, 0], [L21, L22]], given the inverses L11' and L22', the
+    // corresponding block of the inverse is -L22' * L21 * L11'. In our case,
+    // L21 is a vector and our blocks have been rescaled to be unit
+    // triangular, so L22 = L22' = 1.
+
+    // Initialize the output matrix with -1s on the diagonal. We use -1 instead
+    // of 1 because we cannot do matrix-vector multiplies with variable shapes
+    // inside of a loop, or do irregularly shaped in-place updates. Hence,
+    // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
+    // entire row i.e. we calculate
+    // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
+    // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
+    auto identity =
+        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
+    auto neg_identity = -identity;
+
+    // The first or last diagonal element should be set to 1 instead of -1
+    // though, since we never update it
+    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
+    auto start_index =
+        ConstantR0<int>(builder, lower_triangular ? 0 : block_size - 1);
+    auto output_block =
+        DynamicUpdateSlice(neg_identity, pos_one,
+                           /*start_indices=*/{start_index, start_index});
+
+    // Broadcast diag([1, -1, -1, ...]) to every block
+    XlaOp output = Broadcast(output_block,
+                             /*broadcast_sizes=*/{num_blocks});
+
+    // Now we construct a loop that performs matrix-vector multiplications
+    // inverting the blocks one row at a time
+    std::vector<Shape> tuple_shapes = {
+        // The loop iteration counter is a scalar, incremented each iteration.
+        ShapeUtil::MakeShape(S32, {}),
+        // The output has the shape of A, with one row updated each iteration.
+        ShapeUtil::MakeShape(shape.element_type(),
+                             {num_blocks, block_size, block_size}),
+        // The input is a loop invariant.
+        ShapeUtil::MakeShape(shape.element_type(),
+                             {num_blocks, block_size, block_size})};
+    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
+
+    auto init_i = One(builder, S32);
+    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});
+
+    // Construct the loop condition function.
+    std::unique_ptr<XlaBuilder> condb =
+        builder->CreateSubBuilder("InvertDiagCond");
+    {
+      auto i = GetTupleElement(
+          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
+      Lt(i, ConstantR0<int32>(condb.get(), block_size));
+    }
+    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
+
+    // Construct the loop body function.
+    std::unique_ptr<XlaBuilder> bodyb =
+        builder->CreateSubBuilder("InvertDiagBody");
+    {
+      auto input_tuple =
+          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");
+
+      auto i = GetTupleElement(input_tuple, 0);
+      auto body_out = GetTupleElement(input_tuple, 1);
+      auto body_input = GetTupleElement(input_tuple, 2);
+
+      auto zero = ConstantR0<int32>(bodyb.get(), 0);
+      auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
+      auto input_row =
+          DynamicSlice(body_input, {zero, j, zero},
+                       /*slice_sizes=*/{num_blocks, 1, block_size});
+
+      // We want -L21 L11^{-1}
+      DotDimensionNumbers dnums;
+      dnums.add_lhs_batch_dimensions(0);
+      dnums.add_rhs_batch_dimensions(0);
+      dnums.add_lhs_contracting_dimensions(2);
+      dnums.add_rhs_contracting_dimensions(1);
+      PrecisionConfig precision_proto;
+      precision_proto.add_operand_precision(precision);
+      precision_proto.add_operand_precision(precision);
+      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
+
+      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});
+
+      auto next_i = i + ScalarLike(i, 1);
+      Tuple(bodyb.get(), {next_i, body_out, body_input});
+    }
+    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
+
+    // Construct the While loop and return the result,
+    // return while_loop(cond_fun, body_fun, init)[1]
+    auto invert_while = While(cond, body, init);
+    auto inv_diag_blocks = GetTupleElement(invert_while, 1);
+    // Undo the scaling
+    inv_diag_blocks = Div(inv_diag_blocks, diags,
+                          /*broadcast_dimensions=*/{0, 1});
+
+    // Reshape back to original batch major dimensions
+    return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
+  });
+}
+
+XlaOp TriangularSolveExpander::BuildTriangularSolve(
+    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
+    bool conjugate_a, bool unit_diagonal, int64 block_size,
+    PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
@@ -422,6 +431,11 @@
       return b;
     }
 
+    // Degenerate case: 1x1 matrices.
+    if (ShapeUtil::GetDimension(a_shape, -1) == 1) {
+      return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
+    }
+
     // TODO(phawkins): consider pushing triangle masking into
     // InvertDiagonalBlocks.
     if (unit_diagonal) {
@@ -440,8 +454,7 @@
     auto diag_blocks = DiagonalBlocks(a, block_size);
 
     // We invert these blocks in parallel using batched matrix-vector products
-    auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
-                                                conjugate_a, precision);
+    auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision);
 
     // We now find the solution using GEMMs
     auto x =
@@ -452,8 +465,6 @@
   });
 }
 
-}  // namespace
-
 TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
     : block_size_(block_size) {}
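
For reference, the block identity that InvertDiagonalBlocks relies on (and that the comment inside the function states): for a lower triangular matrix in block form,

\[
\begin{pmatrix} L_{11} & 0 \\ L_{21} & L_{22} \end{pmatrix}^{-1}
= \begin{pmatrix} L_{11}^{-1} & 0 \\ -L_{22}^{-1} L_{21} L_{11}^{-1} & L_{22}^{-1} \end{pmatrix},
\]

so once the blocks are rescaled to be unit triangular (L22 = 1), each loop iteration only needs the matrix-vector product -L21 * L11^{-1}. The identity padding introduced in DiagonalBlocks fits the same picture: the padded trailing block is block-diagonal, and

\[
\begin{pmatrix} A & 0 \\ 0 & I \end{pmatrix}^{-1}
= \begin{pmatrix} A^{-1} & 0 \\ 0 & I \end{pmatrix},
\]

so the padding stays inert through the inversion instead of contributing a singular, all-zero diagonal.
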
 
diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h
index 362e855..3f9e58a 100644
--- a/tensorflow/compiler/xla/service/triangular_solve_expander.h
+++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_
 
 #include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/op_expander_pass.h"
 
 namespace xla {
@@ -35,6 +36,14 @@
   StatusOr<HloInstruction*> ExpandInstruction(
       HloInstruction* instruction) override;
 
+  virtual XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular,
+                                     PrecisionConfig::Precision precision);
+
+  XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
+                             bool transpose_a, bool conjugate_a,
+                             bool unit_diagonal, int64 block_size,
+                             PrecisionConfig::Precision precision);
+
  private:
   // Block size for BuildTriangularSolve
   const int64 block_size_;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index c66f9d9..e2b977a 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -333,10 +333,10 @@
   auto builder = HloComputation::Builder(TestName());
   auto constant = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
-  auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
+  auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart(
       ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
                                  ShapeUtil::MakeShape(U32, {})}),
-      HloOpcode::kCopyStart, constant));
+      constant));
   auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
       constant->shape(), HloOpcode::kCopyDone, copy_start));
 
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index c80123b..785fdec 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -37,23 +37,15 @@
 using absl::optional;
 using hlo_query::ContainsInstrWithOpcode;
 
-// Tries to remove elements in a while loop's tuple that aren't used within the
-// loop.
-//
-// Specifically, if a loop is tuple-shaped, and there exists some element of
-// that tuple that is not used by the loop condition and is not used by the loop
-// body except to pass it to the next iteration of the loop, then we can remove
-// that element from the loop's tuples.
-static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
-  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
-
-  // Don't try this transformation if the while loop isn't removable, since if
-  // it succeeds ultimately we're going to have to replace the old while loop
-  // with a new one.
-  if (!while_op->parent()->IsSafelyRemovable(while_op)) {
-    VLOG(2) << "Can't remove dead parameters from non-removable while op.";
-    return false;
-  }
+// This is a utility function that removes from the while loop's init, body,
+// and condition all tuple indices that are *not* in used_tuple_indices. The
+// instruction that replaces the old while op still has the original tuple
+// shape.
+static StatusOr<HloInstruction*> RemoveDeadTupleIndices(
+    HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) {
+  // Build up maps from the old/new to the new/old tuple indices.
+  std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
+                                          used_tuple_indices.end());
+  absl::c_sort(new_to_old_tuple_idx);
 
   HloModule* module = while_op->GetModule();
   HloComputation* computation = while_op->parent();
@@ -62,107 +54,8 @@
   HloComputation* while_body = while_op->while_body();
   HloInstruction* while_body_root = while_body->root_instruction();
 
-  if (!while_init->shape().IsTuple()) {
-    VLOG(2) << "While op's carried value isn't tuple shaped.";
-    return false;
-  }
-
-  if (while_body_root->opcode() != HloOpcode::kTuple) {
-    VLOG(2) << "While body's root is not a tuple(...) instruction.";
-    return false;
-  }
-
   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
 
-  // Bail if param0 of while_cond or while_body has users which aren't of type
-  // get-tuple-element.
-  for (const HloInstruction* instr : {while_body->parameter_instruction(0),
-                                      while_cond->parameter_instruction(0)}) {
-    for (const HloInstruction* user : instr->users()) {
-      if (user->opcode() != HloOpcode::kGetTupleElement) {
-        VLOG(2) << "Cowardly refusing to analyze while loop with "
-                << instr->ToString(print_no_metadata)
-                << " used by non-GTE instruction "
-                << user->ToString(print_no_metadata) << " in computation "
-                << instr->parent()->name();
-        return false;
-      }
-    }
-  }
-
-  const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
-  if (tuple_size == 0) {
-    VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
-               "empty.";
-    return false;
-  }
-
-  absl::flat_hash_set<int64> used_tuple_indices;
-  for (HloComputation* comp : {while_body, while_cond}) {
-    // The HLO verifier ensures that while_input's shape matches while_init's
-    // shape, which we verified above is a tuple.
-    HloInstruction* while_input = comp->parameter_instruction(0);
-
-    for (const HloInstruction* user : while_input->users()) {
-      // This user doesn't count if it's only used by the while body's root, and
-      // the root places the tuple element into the same index of the tuple as
-      // it came from.  That just amounts to us carrying the variable through
-      // the loop.
-      //
-      // Careful: HloInstruction::operand_index returns the first index the
-      // operand appears in, but it may appear more than once!
-      if (user->user_count() == 1 && user->users().front() == while_body_root &&
-          while_body_root->operand_index(user) == user->tuple_index() &&
-          absl::c_count(while_body_root->operands(), user) == 1) {
-        continue;
-      }
-
-      used_tuple_indices.insert(user->tuple_index());
-      if (used_tuple_indices.size() == tuple_size) {
-        VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
-                << " uses all of its inputs; no simplification possible.";
-        return false;
-      }
-    }
-  }
-
-  // If a tuple element is not passed unmodified from the while body's param0
-  // through to the while body's root, count that element as "used", since
-  // removing that element would be observable.
-  for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
-    if (used_tuple_indices.contains(i)) {
-      continue;
-    }
-
-    auto* operand = while_body_root->operand(i);
-    if (operand->opcode() != HloOpcode::kGetTupleElement ||
-        operand->operand(0) != while_body->parameter_instruction(0) ||
-        operand->tuple_index() != i) {
-      VLOG(2) << "Tuple index " << i
-              << " is not passed through loop body unmodified.";
-      used_tuple_indices.insert(i);
-
-      if (used_tuple_indices.size() == tuple_size) {
-        VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
-                << " uses all of its inputs; no simplification possible.";
-        return false;
-      }
-    }
-  }
-
-  // If we got here, used_tuple_indices.size() < tuple_size, meaning some
-  // elements of the loop's tuple aren't used by while_body or while_cond.
-  CHECK_LT(used_tuple_indices.size(), tuple_size);
-
-  VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
-          << " elements from tuple of "
-          << while_op->ToString(print_no_metadata);
-
-  // Build up maps from the old/new to the new/old tuple indices.
-  std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
-                                          used_tuple_indices.end());
-  absl::c_sort(new_to_old_tuple_idx);
-
   absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
   for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
     int64 old_idx = new_to_old_tuple_idx[new_idx];
@@ -288,6 +181,7 @@
   // The tuple simplifier will then simplify this if possible, removing
   // new_tuple and while_init.
   std::vector<HloInstruction*> new_tuple_elems;
+  const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
   for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
     auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
     if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
@@ -305,9 +199,293 @@
   HloInstruction* new_tuple =
       computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple));
+
+  return new_while_op;
+}
+
+// Tries to remove elements in a while loop's tuple that aren't used within the
+// loop.
+//
+// Specifically, if a loop is tuple-shaped, and there exists some element of
+// that tuple that is not used by the loop condition and is not used by the loop
+// body except to pass it to the next iteration of the loop, then we can remove
+// that element from the loop's tuples.
+static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
+  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
+
+  // Don't try this transformation if the while loop isn't removable, since if
+  // it succeeds ultimately we're going to have to replace the old while loop
+  // with a new one.
+  if (!while_op->parent()->IsSafelyRemovable(while_op)) {
+    VLOG(2) << "Can't remove dead parameters from non-removable while op.";
+    return false;
+  }
+
+  HloInstruction* while_init = while_op->mutable_operand(0);
+  HloComputation* while_cond = while_op->while_condition();
+  HloComputation* while_body = while_op->while_body();
+  HloInstruction* while_body_root = while_body->root_instruction();
+
+  if (!while_init->shape().IsTuple()) {
+    VLOG(2) << "While op's carried value isn't tuple shaped.";
+    return false;
+  }
+
+  if (while_body_root->opcode() != HloOpcode::kTuple) {
+    VLOG(2) << "While body's root is not a tuple(...) instruction.";
+    return false;
+  }
+
+  auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
+
+  // Bail if param0 of while_cond or while_body has users which aren't of type
+  // get-tuple-element.
+  for (const HloInstruction* instr : {while_body->parameter_instruction(0),
+                                      while_cond->parameter_instruction(0)}) {
+    for (const HloInstruction* user : instr->users()) {
+      if (user->opcode() != HloOpcode::kGetTupleElement) {
+        VLOG(2) << "Cowardly refusing to analyze while loop with "
+                << instr->ToString(print_no_metadata)
+                << " used by non-GTE instruction "
+                << user->ToString(print_no_metadata) << " in computation "
+                << instr->parent()->name();
+        return false;
+      }
+    }
+  }
+
+  const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
+  if (tuple_size == 0) {
+    VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
+               "empty.";
+    return false;
+  }
+
+  absl::flat_hash_set<int64> used_tuple_indices;
+  for (HloComputation* comp : {while_body, while_cond}) {
+    // The HLO verifier ensures that while_input's shape matches while_init's
+    // shape, which we verified above is a tuple.
+    HloInstruction* while_input = comp->parameter_instruction(0);
+
+    for (const HloInstruction* user : while_input->users()) {
+      // This user doesn't count if it's only used by the while body's root, and
+      // the root places the tuple element into the same index of the tuple as
+      // it came from.  That just amounts to us carrying the variable through
+      // the loop.
+      //
+      // Careful: HloInstruction::operand_index returns the first index the
+      // operand appears in, but it may appear more than once!
+      if (user->user_count() == 1 && user->users().front() == while_body_root &&
+          while_body_root->operand_index(user) == user->tuple_index() &&
+          absl::c_count(while_body_root->operands(), user) == 1) {
+        continue;
+      }
+
+      used_tuple_indices.insert(user->tuple_index());
+      if (used_tuple_indices.size() == tuple_size) {
+        VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
+                << " uses all of its inputs; no simplification possible.";
+        return false;
+      }
+    }
+  }
+
+  // If a tuple element is not passed unmodified from the while body's param0
+  // through to the while body's root, count that element as "used", since
+  // removing that element would be observable.
+  for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
+    if (used_tuple_indices.contains(i)) {
+      continue;
+    }
+
+    auto* operand = while_body_root->operand(i);
+    if (operand->opcode() != HloOpcode::kGetTupleElement ||
+        operand->operand(0) != while_body->parameter_instruction(0) ||
+        operand->tuple_index() != i) {
+      VLOG(2) << "Tuple index " << i
+              << " is not passed through loop body unmodified.";
+      used_tuple_indices.insert(i);
+
+      if (used_tuple_indices.size() == tuple_size) {
+        VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
+                << " uses all of its inputs; no simplification possible.";
+        return false;
+      }
+    }
+  }
+
+  // If we got here, used_tuple_indices.size() < tuple_size, meaning some
+  // elements of the loop's tuple aren't used by while_body or while_cond.
+  CHECK_LT(used_tuple_indices.size(), tuple_size);
+
+  VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
+          << " elements from tuple of "
+          << while_op->ToString(print_no_metadata);
+
+  TF_ASSIGN_OR_RETURN(while_op,
+                      RemoveDeadTupleIndices(while_op, used_tuple_indices));
+
   return true;
 }
 
+// Helper for TryRemoveRepeatedWhileTupleIndices. It replaces all uses of the
+// duplicate tuple indices with the element at tuple_index, then removes the
+// duplicates via a call to RemoveDeadTupleIndices.
+static StatusOr<HloInstruction*> TryRemoveRepeatedWhileTupleIndicesHelper(
+    HloInstruction* while_op, const int64 tuple_index,
+    absl::flat_hash_set<int64>& duplicates) {
+  HloComputation* while_cond = while_op->while_condition();
+  HloComputation* while_body = while_op->while_body();
+  HloInstruction* while_init = while_op->mutable_operand(0);
+
+  VLOG(2) << "while_init " << while_init->ToString() << " operands "
+          << while_init->operand_count();
+  VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString()
+          << " operands " << while_body->root_instruction()->operand_count();
+
+  // Change the loop body and condition such that uses of the duplicates are
+  // replaced with the original tuple element.
+  for (HloComputation* comp : {while_body, while_cond}) {
+    auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+        comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index),
+        comp->parameter_instruction(0), tuple_index));
+
+    std::vector<HloInstruction*> instrs_to_replace;
+    for (auto* instr : comp->instructions()) {
+      if (instr->opcode() == HloOpcode::kGetTupleElement &&
+          duplicates.contains(instr->tuple_index()) &&
+          instr->operand(0) == comp->parameter_instruction(0)) {
+        instrs_to_replace.push_back(instr);
+      }
+    }
+
+    for (auto instr : instrs_to_replace) {
+      TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get));
+    }
+  }
+
+  // We know which tuple indices are useful, i.e., those that aren't duplicates.
+  absl::flat_hash_set<int64> used_tuple_indices;
+  for (int index = 0; index < while_init->shape().tuple_shapes_size();
+       ++index) {
+    if (!duplicates.count(index)) {
+      used_tuple_indices.insert(index);
+    }
+  }
+
+  // Remove the duplicate tuple elements.
+  TF_ASSIGN_OR_RETURN(while_op,
+                      RemoveDeadTupleIndices(while_op, used_tuple_indices));
+
+  return while_op;
+}
+
+// If the while loop's init passes the same value to several tuple indices, and
+// the body keeps passing those elements through unchanged, we can remove the
+// duplicates.
+static StatusOr<bool> TryRemoveRepeatedWhileTupleIndices(
+    HloInstruction* while_op) {
+  CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
+
+  int index_to_investigate = 0;
+  // Don't try this transformation if the while loop isn't removable, since if
+  // it succeeds ultimately we're going to have to replace the old while loop
+  // with a new one.
+  if (!while_op->parent()->IsSafelyRemovable(while_op)) {
+    VLOG(2) << "Can't remove dead parameters from non-removable while op.";
+    return false;
+  }
+
+  HloInstruction* while_init = while_op->mutable_operand(0);
+  HloComputation* while_cond = while_op->while_condition();
+  HloComputation* while_body = while_op->while_body();
+  HloInstruction* while_body_root = while_body->root_instruction();
+
+  if (!while_init->shape().IsTuple()) {
+    VLOG(2) << "While op's carried value isn't tuple shaped.";
+    return false;
+  }
+
+  bool changed = false;
+  while (index_to_investigate < while_init->shape().tuple_shapes_size()) {
+    if (!while_init->shape().IsTuple() ||
+        while_init->opcode() != HloOpcode::kTuple) {
+      VLOG(2) << "While op's carried value isn't tuple shaped.";
+      return false;
+    }
+
+    if (while_body_root->opcode() != HloOpcode::kTuple) {
+      VLOG(2) << "While body's root is not a tuple(...) instruction.";
+      return false;
+    }
+
+    auto& while_shape = while_init->shape();
+    VLOG(2) << "Iterating " << index_to_investigate;
+
+    absl::flat_hash_set<int64> duplicates;
+    auto* pivot_init_elem = while_init->operand(index_to_investigate);
+    auto* pivot_body_elem = while_body_root->operand(index_to_investigate);
+    if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement &&
+        pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) {
+      if (pivot_body_elem->tuple_index() != index_to_investigate) {
+        VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() "
+                << pivot_body_elem->tuple_index() << " index_to_investigate "
+                << index_to_investigate;
+        index_to_investigate++;
+        continue;
+      }
+    } else {
+      index_to_investigate++;
+      continue;
+    }
+
+    // Look from index_to_investigate onwards to see if it is repeated.
+    for (int64 i = index_to_investigate + 1;
+         i < while_shape.tuple_shapes_size(); ++i) {
+      auto* init_elem = while_init->operand(i);
+      auto* body_elem = while_body_root->operand(i);
+      if (body_elem->opcode() == HloOpcode::kGetTupleElement &&
+          body_elem->operand(0) == while_body->parameter_instruction(0)) {
+        if (body_elem->tuple_index() != i) {
+          VLOG(2) << "Mismatch between body_elem->tuple_index() "
+                  << body_elem->tuple_index() << " i " << i;
+          continue;
+        }
+      } else {
+        continue;
+      }
+
+      if (pivot_init_elem == init_elem) {
+        VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem "
+                << pivot_init_elem->ToString();
+        VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem "
+                << pivot_body_elem->ToString();
+        duplicates.insert(i);
+      }
+    }
+
+    // If duplicates are found, call the helper to remove them.
+    if (!duplicates.empty()) {
+      VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init "
+              << pivot_init_elem->ToString();
+      TF_ASSIGN_OR_RETURN(while_op,
+                          TryRemoveRepeatedWhileTupleIndicesHelper(
+                              while_op, index_to_investigate, duplicates));
+      changed = true;
+      VLOG(2) << "Changed while_op " << while_op->ToString()
+              << " while_op operand count " << while_op->operand_count();
+      // Update the while loop variables so we can continue looking for
+      // duplicates of a different index.
+      while_init = while_op->mutable_operand(0);
+      while_cond = while_op->while_condition();
+      while_body = while_op->while_body();
+      while_body_root = while_body->root_instruction();
+    }
+    index_to_investigate++;
+  }
+
+  return changed;
+}
+
 // Removes each loop parameter (i.e. member of the while loop tuple) that is a
 // constant and is the same in the while loop body and the while loop init.
 static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
@@ -1048,6 +1226,7 @@
 
     TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
     changed |= result;
+
     if (result) {
       // Don't continue simplifying after successfully removing the while loop
       // -- that would result in use-after-free nastiness.
@@ -1067,6 +1246,12 @@
     // successful, meaning that `while_op` is no longer valid after one of these
     // transformations returns true.
 
+    TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op));
+    changed |= result;
+    if (result) {
+      continue;
+    }
+
     TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
     changed |= result;
     if (result) {
@@ -1074,6 +1259,7 @@
     }
 
     TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
+
     changed |= result;
     if (result) {
       continue;
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index d715fb3..c93cb5d 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -794,5 +794,51 @@
           .ValueOrDie());
 }
 
+TEST_F(WhileLoopSimplifierTest, RemoveRepeatedParams) {
+  const string hlo_string = R"(
+  HloModule SwappingTupleElements
+
+  SwappingTupleElements.body {
+    loop_var = (s32[], s32[], s32[]) parameter(0)
+    get-tuple-element = s32[] get-tuple-element(loop_var), index=0
+    get-tuple-element.1 = s32[] get-tuple-element(loop_var), index=1
+    get-tuple-element.2 = s32[] get-tuple-element(loop_var), index=2
+    y = s32[] add(get-tuple-element.1, get-tuple-element.2)
+    ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element, y,
+      s32[] get-tuple-element.2)
+  }
+
+  SwappingTupleElements.always_true {
+   param = (s32[], s32[], s32[]) parameter(0)
+   get-tuple-element = s32[] get-tuple-element(param), index=0
+   get-tuple-element.1 = s32[] get-tuple-element(param), index=1
+   ROOT less-than = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
+  }
+
+  ENTRY SwappingTupleElements {
+   x = s32[] parameter(0)
+   y = s32[] parameter(1)
+   tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] y, s32[] x)
+   ROOT while = (s32[], s32[], s32[]) while(tuple.1),
+     condition=SwappingTupleElements.always_true,
+     body=SwappingTupleElements.body
+  }
+  )";
+
+  auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
+  HloInstruction* new_while = FindFirstWhile(m.get());
+  Shape new_while_shape = ParseShape("(s32[], s32[])").ValueOrDie();
+  EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape));
+  EXPECT_TRUE(ShapeUtil::Equal(
+      new_while->while_body()->root_instruction()->shape(), new_while_shape));
+  EXPECT_TRUE(ShapeUtil::Equal(
+      new_while->while_body()->parameter_instruction(0)->shape(),
+      new_while_shape));
+  EXPECT_TRUE(ShapeUtil::Equal(
+      new_while->while_condition()->parameter_instruction(0)->shape(),
+      new_while_shape));
+}
+
 }  // namespace
 }  // namespace xla
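
The RemoveRepeatedParams test above exercises the duplicate-index path. For the TryRemoveDeadWhileParams path refactored earlier in this change, a hypothetical companion test in the same style might look like the sketch below; the module, instruction names, and expected shape are illustrative assumptions, not part of this change.

```cpp
TEST_F(WhileLoopSimplifierTest, RemoveDeadParamIllustration) {
  // Tuple index 1 is read only to be passed back to the same index, and the
  // condition never reads it, so TryRemoveDeadWhileParams should drop it.
  const string hlo_string = R"(
  HloModule RemoveDeadParam

  RemoveDeadParam.body {
    loop_var = (s32[], s32[]) parameter(0)
    get-tuple-element = s32[] get-tuple-element(loop_var), index=0
    get-tuple-element.1 = s32[] get-tuple-element(loop_var), index=1
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element, constant.1)
    ROOT tuple = (s32[], s32[]) tuple(s32[] add, s32[] get-tuple-element.1)
  }

  RemoveDeadParam.cond {
    param = (s32[], s32[]) parameter(0)
    get-tuple-element = s32[] get-tuple-element(param), index=0
    constant.10 = s32[] constant(10)
    ROOT less-than = pred[] compare(get-tuple-element, constant.10), direction=LT
  }

  ENTRY RemoveDeadParam {
    x = s32[] parameter(0)
    y = s32[] parameter(1)
    tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y)
    ROOT while = (s32[], s32[]) while(tuple.1),
      condition=RemoveDeadParam.cond, body=RemoveDeadParam.body
  }
  )";

  auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
  HloInstruction* new_while = FindFirstWhile(m.get());
  Shape new_while_shape = ParseShape("(s32[])").ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape));
}
```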
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 73bb332..bc48a9c 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -648,7 +648,9 @@
                                    const ShapeIndex& target_base_index) {
   CHECK(ShapeUtil::Compatible(
       ShapeUtil::GetSubshape(shape(), target_base_index),
-      ShapeUtil::GetSubshape(other.shape(), source_base_index)));
+      ShapeUtil::GetSubshape(other.shape(), source_base_index)))
+      << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs "
+      << ShapeUtil::GetSubshape(other.shape(), source_base_index);
   ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
                             const ShapeIndex& index, T* data) {
     // Copy the data element only if index is in the
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 17444c0..734d2ed 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2699,5 +2699,6 @@
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
+        "//tensorflow/core/platform:tf32_utils",
     ],
 )
diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc
index e7f5ca5..9a86852 100644
--- a/tensorflow/compiler/xla/tests/cholesky_test.cc
+++ b/tensorflow/compiler/xla/tests/cholesky_test.cc
@@ -30,6 +30,7 @@
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/tf32_utils.h"
 
 namespace xla {
 namespace {
@@ -181,6 +182,7 @@
       public ::testing::WithParamInterface<CholeskyTestCase> {};
 
 XLA_TEST_P(RandomCholeskyTest, Random) {
+  tensorflow::allow_tf32_execution(false);  // Test fails with tf32 allowed
   XlaBuilder builder(TestName());
 
   auto test_params = GetParam();
diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
index ba4092d..a7e0324 100644
--- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc
@@ -104,12 +104,26 @@
   }
 }
 
+TEST_F(DynamismInferenceTest, TupleSimple) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+
+    auto tuple = Tuple(&b, {c, p});
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(),
+              false);
+    EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true);
+  }
+}
+
 TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto tuple = Tuple(&b, {c, p});
     auto gte0 = GetTupleElement(tuple, 0);
@@ -122,12 +136,25 @@
   }
 }
 
+TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
+  for (ClientType client_type : client_types) {
+    Client* client = ClientOrDie(platform_, client_type);
+    XlaBuilder b(TestName());
+    auto c = ConstantR0<int32>(&b, 42);
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
+    auto pred = Eq(c, p);
+    auto result = Select(pred, p, c);
+    EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(),
+              false);
+  }
+}
+
 TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto concat = ConcatScalars(&b, {c, p});
     auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
@@ -146,7 +173,7 @@
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
-    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto value = ComputeDynamismScalar(client, computation, &b);
     ASSERT_TRUE(value.ok()) << value.status();
@@ -160,7 +187,7 @@
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     auto neg0 = Neg(c);
     auto neg1 = Neg(p);
@@ -177,7 +204,7 @@
     Client* client = ClientOrDie(platform_, client_type);
     XlaBuilder b(TestName());
     auto c = ConstantR0<int32>(&b, 42);
-    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
 
     // Static value + static value = static
     auto add1 = Add(c, c);
@@ -198,8 +225,8 @@
     // param = Param([<=2, 3])
     // get_dimension_size(param, 0) is dynamic
     // get_dimension_size(param, 1) is static
-    auto p =
-        Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0");
+    auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}),
+                       "p0");
 
     auto gds0 = GetDimensionSize(p, 0);
     auto gds1 = GetDimensionSize(p, 1);
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 0fd5f19..0f8a4c1 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -711,6 +711,24 @@
   RunTest(hlo_text, &operand, &start_indices);
 }
 
+XLA_TEST_F(GatherOperationTest, GatherFromScalarNonZeroIndices) {
+  const string hlo_text = R"(
+HloModule GatherFromScalar
+
+ENTRY main {
+  operand = f32[1,1,1] parameter(0)
+  indices = s32[2,3,50] parameter(1)
+  ROOT gather = f32[1,2,50] gather(operand, indices),
+      offset_dims={0},
+      collapsed_slice_dims={0,1},
+      start_index_map={1,0,2},
+      index_vector_dim=1,
+      slice_sizes={1,1,1}
+}
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0, 0}));
+}
+
 class GatherClientLibraryTest : public ClientLibraryTestBase {};
 
 // Disabled on interpreter since ExecuteAsyncOnStream is not supported.
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6f5e366..3068a01 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -670,8 +670,9 @@
         ":lib",
         ":protos_all_cc",
         # TODO(b/162630222): remove this dependency.
-        "//tensorflow/c/kernels:summary_op_lib",
         "//tensorflow/c/kernels:histogram_summary_op_lib",
+        "//tensorflow/c/kernels:merge_summary_op_lib",
+        "//tensorflow/c/kernels:summary_op_lib",
     ],
 )
 
@@ -876,6 +877,7 @@
         ":word2vec_ops",
         "//tensorflow/c/kernels:bitcast_op_lib",
         "//tensorflow/c/kernels:histogram_summary_op_lib",
+        "//tensorflow/c/kernels:merge_summary_op_lib",
         "//tensorflow/c/kernels:summary_op_lib",
         "//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op",
     ] + if_chromiumos(
@@ -984,9 +986,10 @@
     name = "all_kernels_impl",
     visibility = [":__subpackages__"],
     deps = [
-        "//tensorflow/c/kernels:histogram_summary_op",
-        "//tensorflow/c/kernels:summary_op",
         "//tensorflow/c/kernels:bitcast_op",
+        "//tensorflow/c/kernels:histogram_summary_op",
+        "//tensorflow/c/kernels:merge_summary_op",
+        "//tensorflow/c/kernels:summary_op",
         "//tensorflow/core/kernels:array",
         "//tensorflow/core/kernels:audio",
         "//tensorflow/core/kernels:batch_kernels",
@@ -2962,6 +2965,8 @@
     srcs = [
         # PNG data
         "//tensorflow/core/lib/png:testdata",
+        "//tensorflow/core/lib/ssim:testdata",
+        "//tensorflow/core/lib/psnr:testdata",
         # JPEG data
         "lib/jpeg/testdata/jpeg_merge_test1.jpg",
         "lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg",
@@ -2991,37 +2996,13 @@
         "lib/bmp/testdata/grayscale_small.bmp",
         "lib/bmp/testdata/grayscale_small_3channels.bmp",
         "lib/bmp/testdata/grayscale_small_4channels.bmp",
-        # SSIM, PSNR data
-        "lib/ssim/testdata/checkerboard1.png",
-        "lib/ssim/testdata/checkerboard2.png",
-        "lib/ssim/testdata/checkerboard3.png",
-        "lib/psnr/testdata/cat_q20.jpg",
-        "lib/psnr/testdata/cat_q72.jpg",
-        "lib/psnr/testdata/cat_q95.jpg",
     ],
     visibility = ["//visibility:public"],
 )
 
-filegroup(
+alias(
     name = "lmdb_testdata",
-    testonly = 1,
-    srcs = [
-        # A simple key-value store:
-        #   0 : 'b'
-        #   1 : 'b'
-        #    ...
-        #   9 : 'b'
-        # Which is then overwritten with:
-        #   0 : 'a'
-        #   1 : 'b'
-        #    ...
-        #   9 : 'j'
-        "lib/lmdb/testdata/data.mdb",
-        # LMDB, being a memory-mapped database, uses a different file format on
-        # big-endian systems.
-        "lib/lmdb/testdata/data_bigendian.mdb",
-    ],
-    visibility = ["//visibility:public"],
+    actual = "//tensorflow/core/lib/lmdb:lmdb_testdata",
 )
 
 alias(
diff --git a/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt b/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt
index 2184b64..dc018ae 100644
--- a/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Acos.pbtxt
@@ -1,4 +1,11 @@
 op {
   graph_op_name: "Acos"
   summary: "Computes acos of x element-wise."
+  description: <<END
+
+  Given an input tensor, the `tf.math.acos` operation returns the inverse cosine of each element of the tensor. If `y = tf.math.cos(x)`, then `x = tf.math.acos(y)`.
+
+  The input range is `[-1, 1]` and the output range is `[0, pi]`.
+
+END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_Add.pbtxt b/tensorflow/core/api_def/base_api/api_def_Add.pbtxt
index 7a408af..db0c12a 100644
--- a/tensorflow/core/api_def/base_api/api_def_Add.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Add.pbtxt
@@ -4,5 +4,10 @@
   description: <<END
 *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+
+Given two input tensors, the `tf.add` operation computes the element-wise sum of the tensors.
+
+Both the inputs and the output have a range of `(-inf, inf)`.
+
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyCenteredRMSProp.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyCenteredRMSProp.pbtxt
index c88d18d..03f4020 100644
--- a/tensorflow/core/api_def/base_api/api_def_ApplyCenteredRMSProp.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ApplyCenteredRMSProp.pbtxt
@@ -37,6 +37,12 @@
 END
   }
   in_arg {
+    name: "momentum"
+    description: <<END
+Momentum scale. Must be a scalar.
+END
+  }
+  in_arg {
     name: "epsilon"
     description: <<END
 Ridge term. Must be a scalar.
diff --git a/tensorflow/core/api_def/base_api/api_def_DataFormatVecPermute.pbtxt b/tensorflow/core/api_def/base_api/api_def_DataFormatVecPermute.pbtxt
index d87c088..5e73607 100644
--- a/tensorflow/core/api_def/base_api/api_def_DataFormatVecPermute.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DataFormatVecPermute.pbtxt
@@ -24,8 +24,27 @@
 destination data format.
 END
   }
-  summary: "Returns the permuted vector/tensor in the destination data format given the"
+  summary: "Permute input tensor from `src_format` to `dst_format`."
   description: <<END
-one in the source data format.
+Input tensor must be a vector of size 4, or a 4x2 tensor.
+
+For example, with `src_format` of `NHWC`, `dst_format` of `NCHW`, and inputs:
+```
+[1, 2, 3, 4]
+```
+and
+```
+[[1, 2, 3, 4],
+ [5, 6, 7, 8]]
+```
+The outputs will be, respectively:
+```
+[1, 4, 2, 3]
+```
+and
+```
+[[1, 4, 2, 3],
+ [5, 8, 6, 7]]
+```
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_LogMatrixDeterminant.pbtxt b/tensorflow/core/api_def/base_api/api_def_LogMatrixDeterminant.pbtxt
index 8245f7d..018326c 100644
--- a/tensorflow/core/api_def/base_api/api_def_LogMatrixDeterminant.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_LogMatrixDeterminant.pbtxt
@@ -26,9 +26,9 @@
 The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions
 form square matrices. The outputs are two tensors containing the signs and
 absolute values of the log determinants for all N input submatrices
-`[..., :, :]` such that the determinant = sign*exp(log_abs_determinant).
-The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU
-is the LU decomposition of the input and P is the corresponding
+`[..., :, :]` such that `determinant = sign*exp(log_abs_determinant)`.
+The `log_abs_determinant` is computed as `det(P)*sum(log(diag(LU)))` where `LU`
+is the `LU` decomposition of the input and `P` is the corresponding
 permutation matrix.
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceApplyCenteredRMSProp.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceApplyCenteredRMSProp.pbtxt
index 9cc033c..3f30676 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResourceApplyCenteredRMSProp.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceApplyCenteredRMSProp.pbtxt
@@ -37,6 +37,12 @@
 END
   }
   in_arg {
+    name: "momentum"
+    description: <<END
+Momentum scale. Must be a scalar.
+END
+  }
+  in_arg {
     name: "epsilon"
     description: <<END
 Ridge term. Must be a scalar.
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorMapStackKeys.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorMapStackKeys.pbtxt
new file mode 100644
index 0000000..a8ecb43
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorMapStackKeys.pbtxt
@@ -0,0 +1,8 @@
+op {
+  graph_op_name: "TensorMapStackKeys"
+  summary: "Returns a Tensor stack of all keys in a tensor map."
+  description: <<END
+input_handle: the input map
+keys: the returned Tensor of all keys in the map
+END
+}
\ No newline at end of file
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
index 4cd9f82..b958a25 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local.cc
@@ -21,9 +21,6 @@
 
 void CollectiveRemoteAccessLocal::StartAbort(const Status& s) {
   buf_rendezvous_.StartAbort(s);
-  if (errors::IsFailedPrecondition(s)) {
-    dev_resolver_->ClearCache();
-  }
 }
 
 void CollectiveRemoteAccessLocal::RecvFromPeer(
@@ -108,6 +105,13 @@
                              from_alloc_attr, done);
 }
 
+void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+                                                  const StatusCallback& done) {
+  // Local devices are always healthy, so there is never a need to check peer
+  // health; this method should not be called for local collectives.
+  done(errors::Internal(
+      "CheckPeerHealth is not supposed to be called for local collectives"));
+}
+
 /*static*/
 void CollectiveRemoteAccessLocal::MemCpyAsync(
     DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index 8a0bbd5..12aca90 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -53,6 +53,9 @@
                   const DeviceLocality& client_locality,
                   const StatusCallback& done) override;
 
+  void CheckPeerHealth(const string& peer_task,
+                       const StatusCallback& done) override;
+
   BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }
 
   // Copy utility that always copies bytes from src to dst even if
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index d721fc3..2c60614 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -151,5 +151,16 @@
   EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor));
 }
 
+TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
+  Status status;
+  Notification done;
+  rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
+    status = s;
+    done.Notify();
+  });
+  done.WaitForNotification();
+  EXPECT_TRUE(errors::IsInternal(status));
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.cc b/tensorflow/core/common_runtime/device_resolver_local.cc
index 12e1e28..9a898e7 100644
--- a/tensorflow/core/common_runtime/device_resolver_local.cc
+++ b/tensorflow/core/common_runtime/device_resolver_local.cc
@@ -46,4 +46,10 @@
   done(s);
 }
 
+Status DeviceResolverLocal::GetTaskCached(
+    const string& task, std::vector<DeviceAttributes>* attributes) {
+  return errors::Internal(
+      "GetTaskCached is not supposed to be called in local collectives");
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h
index 53a3c87..1cd84af 100644
--- a/tensorflow/core/common_runtime/device_resolver_local.h
+++ b/tensorflow/core/common_runtime/device_resolver_local.h
@@ -39,9 +39,8 @@
                                 DeviceAttributes* attributes,
                                 const StatusCallback& done) override;
 
-  void ClearTask(const string& task) override {}
-
-  void ClearCache() override {}
+  Status GetTaskCached(const string& task,
+                       std::vector<DeviceAttributes>* attributes) override;
 
  protected:
   const DeviceMgr* dev_mgr_;
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 196c463..cebfbdd 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -72,8 +72,7 @@
 
 EagerContext::EagerContext(
     const SessionOptions& opts,
-    ContextDevicePlacementPolicy default_device_placement_policy,
-    ContextMirroringPolicy default_mirroring_policy, bool async,
+    ContextDevicePlacementPolicy default_device_placement_policy, bool async,
     const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr,
     bool device_mgr_owned, Rendezvous* rendezvous,
     const CustomKernelCreator* custom_kernel_creator,
@@ -81,7 +80,6 @@
     : ImmediateExecutionContext(kEager),
       opts_(opts),
       default_device_placement_policy_(default_device_placement_policy),
-      default_mirroring_policy_(default_mirroring_policy),
       local_device_manager_(device_mgr, device_mgr_owned),
       host_cpu_device_(device_mgr->HostCPU()),
       rendezvous_(rendezvous),
@@ -403,25 +401,6 @@
   return default_device_placement_policy_;
 }
 
-void EagerContext::SetThreadLocalMirroringPolicy(
-    ContextMirroringPolicy policy) {
-  mutex_lock ml(policy_map_mu_);
-  mirroring_policy_[std::this_thread::get_id()] = policy;
-}
-
-ContextMirroringPolicy EagerContext::GetMirroringPolicy() const {
-  tf_shared_lock l(policy_map_mu_);
-  auto policy_map_it = mirroring_policy_.find(std::this_thread::get_id());
-  if (policy_map_it != mirroring_policy_.end()) {
-    return policy_map_it->second;
-  }
-  return default_mirroring_policy_;
-}
-
-bool EagerContext::MirrorTensors() const {
-  return GetMirroringPolicy() == MIRRORING_ALL;
-}
-
 bool EagerContext::LazyCopyFunctionRemoteInputs() const {
   return lazy_copy_function_remote_inputs_;
 }
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 0daee41..3785df7 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -94,18 +94,6 @@
 };
 // LINT.ThenChange(//tensorflow/c/eager/c_api.h)
 
-// LINT.IfChange
-// Note: Keep in sync with exported copy of enum in eager/c_api_experimental.h.
-enum ContextMirroringPolicy {
-  // Do not maintain mirrors in a TensorHandle, instead make new TensorHandle
-  // copies with their own lifetime.
-  MIRRORING_NONE = 0,
-  // Mirroring any remote tensor handles, associating them with the lifetime of
-  // the local TensorHandle.
-  MIRRORING_ALL = 1,
-};
-// LINT.ThenChange(//tensorflow/c/eager/c_api_experimental.h)
-
 class RunMetadataListener {
  public:
   virtual ~RunMetadataListener() {}
@@ -126,7 +114,7 @@
                                       const string& target_device_name,
                                       TensorHandle** result) = 0;
 
-  virtual Status Execute(EagerOperation* op, TensorHandle** retvals,
+  virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
                          int* num_retvals) = 0;
 };
 
@@ -149,8 +137,7 @@
 
   EagerContext(const SessionOptions& opts,
                ContextDevicePlacementPolicy default_device_placement_policy,
-               ContextMirroringPolicy default_mirroring_policy, bool async,
-               const bool lazy_copy_function_remote_inputs,
+               bool async, const bool lazy_copy_function_remote_inputs,
                const DeviceMgr* device_mgr, bool device_mgr_owned,
                Rendezvous* rendezvous,
                const CustomKernelCreator* custom_kernel_creator,
@@ -234,14 +221,6 @@
   Status SelectDevice(DeviceNameUtils::ParsedName preferred,
                       const NodeDef& ndef, Device** out) const;
 
-  // Sets the implicit copy policy for the current thread.
-  void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
-
-  // Returns the implicit copy policy for the current thread.
-  ContextMirroringPolicy GetMirroringPolicy() const;
-
-  bool MirrorTensors() const;
-
   bool LazyCopyFunctionRemoteInputs() const;
 
   bool FindFunctionByName(const string& name) const;
@@ -557,15 +536,12 @@
 
   SessionOptions opts_;
   const ContextDevicePlacementPolicy default_device_placement_policy_;
-  const ContextMirroringPolicy default_mirroring_policy_;
 
   // Note: we cannot use C++11 thread_local here as there is no concept of a
   // thread-local-object-local variable in C++11.
   mutable mutex policy_map_mu_;
   std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
       device_placement_policy_ TF_GUARDED_BY(policy_map_mu_);
-  std::unordered_map<std::thread::id, ContextMirroringPolicy> mirroring_policy_
-      TF_GUARDED_BY(policy_map_mu_);
 
   OwnedOrUnownedHelper<const DeviceMgr> local_device_manager_;
   // Maintain copy of all previously created local device managers.
diff --git a/tensorflow/core/common_runtime/eager/context_test.cc b/tensorflow/core/common_runtime/eager/context_test.cc
index e577b1d..10ca2eb 100644
--- a/tensorflow/core/common_runtime/eager/context_test.cc
+++ b/tensorflow/core/common_runtime/eager/context_test.cc
@@ -56,7 +56,6 @@
     InitDeviceManager();
     context_ = new EagerContext(
         opts, policy,
-        /* default_mirroring_policy */ MIRRORING_NONE,
         /* async */ false,
         /* lazy_copy_function_remote_inputs */ false, device_manager_,
         /* device_mgr_owned */ false, /* rendezvous */ nullptr,
diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc
index ff63c70..d1e1218 100644
--- a/tensorflow/core/common_runtime/eager/core.cc
+++ b/tensorflow/core/common_runtime/eager/core.cc
@@ -208,12 +208,18 @@
       device = ctx_.HostCPU();
     }
   }
+
+  tensorflow::TensorHandle** retval_array =
+      reinterpret_cast<tensorflow::TensorHandle**>(retvals.data());
+  if (VariantDeviceIsCustom(device)) {
+    return absl::get<CustomDevice*>(device)->Execute(this, retval_array,
+                                                     num_retvals);
+  }
+
   if (device != kVariantDeviceNull) {
     SetDevice(device);
   }
-  return EagerExecute(
-      this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()),
-      num_retvals);
+  return EagerExecute(this, retval_array, num_retvals);
 }
 
 }  //  namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc
index b433cc4..25b03e2 100644
--- a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc
+++ b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc
@@ -47,9 +47,8 @@
       "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       SessionOptions(),
-      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
-      &device_mgr, false, nullptr, nullptr);
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
+      false, &device_mgr, false, nullptr, nullptr);
   EagerOperation orig_op(ctx);
   std::unique_ptr<tensorflow::EagerOperation> out_op;
   EXPECT_EQ(Status::OK(),
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc
index 947b67a..6d1ecf6 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.cc
+++ b/tensorflow/core/common_runtime/eager/eager_operation.cc
@@ -277,11 +277,6 @@
   return InferInputListAttrs(inputs.size());
 }
 
-Status EagerOperation::SetUseXla(bool enable) {
-  use_xla_ = enable;
-  return Status::OK();
-}
-
 Status EagerOperation::Reset(
     const char* op, const char* device_name, bool remote,
     EagerExecutor* executor,
@@ -313,7 +308,6 @@
         "registered in the binary running in this process.");
   }
   attrs_.Reset(op);
-  use_xla_ = false;
   stack_trace_.reset();
   is_function_ = is_function;
   cancellation_manager_ = nullptr;
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h
index 327411e..2e35dd4 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.h
+++ b/tensorflow/core/common_runtime/eager/eager_operation.h
@@ -120,8 +120,6 @@
   Status InputLength(const char* input_name, int* length) override;
   Status OutputLength(const char* output_name, int* length) override;
 
-  Status SetUseXla(bool enable) override;
-
   void SetStackTrace(AbstractStackTrace stack_trace) override {
     stack_trace_ = stack_trace;
   }
@@ -227,7 +225,6 @@
   // updated accordingly.
   VariantDevice device_;
 
-  bool use_xla_ = false;
   absl::optional<AbstractStackTrace> stack_trace_;
   bool is_function_;  // Conceptually const, but can't be because of Reset
   bool colocation_exempt_;
@@ -257,6 +254,11 @@
   return down_cast<EagerOperation*>(operation);
 }
 
+inline const EagerOperation* OperationFromInterface(
+    const ImmediateExecutionOperation* operation) {
+  return down_cast<const EagerOperation*>(operation);
+}
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
diff --git a/tensorflow/core/common_runtime/eager/eager_operation_test.cc b/tensorflow/core/common_runtime/eager/eager_operation_test.cc
index 352c7f0..0c98db9 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation_test.cc
+++ b/tensorflow/core/common_runtime/eager/eager_operation_test.cc
@@ -27,9 +27,8 @@
       "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
   auto ctx = new EagerContext(
       SessionOptions(),
-      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
-      &device_mgr, false, nullptr, nullptr, nullptr);
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
+      false, &device_mgr, false, nullptr, nullptr, nullptr);
 
   auto op = new EagerOperation(ctx);
 
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 2458214..cfb849c 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -1070,11 +1070,6 @@
       [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
       profiler::TraceMeLevel::kInfo);
 
-  if (VariantDeviceIsCustom(op->Device())) {
-    return absl::get<CustomDevice*>(op->Device())
-        ->Execute(op, retvals, num_retvals);
-  }
-
   if (!op->Executor().Async()) {
     // In sync mode, always clear error to maintain the same behavior as before.
     // TODO(b/141004939): Remove this.
diff --git a/tensorflow/core/common_runtime/eager/execute_node_test.cc b/tensorflow/core/common_runtime/eager/execute_node_test.cc
index 54df63f..b7fae20 100644
--- a/tensorflow/core/common_runtime/eager/execute_node_test.cc
+++ b/tensorflow/core/common_runtime/eager/execute_node_test.cc
@@ -67,9 +67,8 @@
 
   auto ctx = new EagerContext(
       SessionOptions(),
-      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
-      &device_mgr, false, nullptr, nullptr, nullptr);
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
+      false, &device_mgr, false, nullptr, nullptr, nullptr);
 
   // Set a RemoteMgr to the EagerContext.
   auto remote_mgr = absl::make_unique<eager::RemoteMgr>(
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
index df4423d..105b0d9 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -150,7 +150,7 @@
 Status MklEagerOpRewrite::CreateMklConv2DOp(
     EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
   const string mkl_op_name =
-      mkl_op_registry::GetMklEagerOpName(orig_op->Name());
+      mkl_op_registry::GetMklNativeOpName(orig_op->Name());
   TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_conv2d_op));
   return Status::OK();
 }
@@ -210,7 +210,7 @@
   if (element != mkl_eager_ops_.end()) {
     // Eager Op exists. So verify registry and return registered or not.
     return (mkl_op_registry::IsMklNameChangeOp(
-                mkl_op_registry::GetMklEagerOpName(op_name), dt) ||
+                mkl_op_registry::GetMklNativeOpName(op_name), dt) ||
             mkl_op_registry::IsMklNameChangeOp(
                 mkl_op_registry::GetMklOpName(op_name), dt));
   } else {
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc
index bcae4e6..5cf4021 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc
@@ -40,8 +40,7 @@
     tensorflow::EagerContext* eager_ctx = new tensorflow::EagerContext(
         SessionOptions(),
         tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, async,
-        lazy_remote_tensor_copy, device_mgr.get(), false, rendezvous,
+        async, lazy_remote_tensor_copy, device_mgr.get(), false, rendezvous,
         GetDefaultCustomKernelCreator());
 
     EagerExecutor executor_(false);
@@ -74,7 +73,7 @@
     auto orig_op = CreateOp("Conv2D");                \
     orig_op->MutableAttrs()->Set("T", T);             \
     orig_op->MutableAttrs()->Set("padding", "VALID"); \
-    CheckRewrite(orig_op.get(), "_MklEagerConv2D");   \
+    CheckRewrite(orig_op.get(), "_MklNativeConv2D");  \
   }
 REGISTER_TEST_ALL_TYPES(Conv2D);
 #undef REGISTER_TEST
@@ -89,22 +88,22 @@
 REGISTER_TEST_ALL_TYPES(Conv2D_Explicit_Padding);
 #undef REGISTER_TEST
 
-#define REGISTER_TEST(NAME, T, INPUT)                            \
-  TEST_F(EagerOpRewriteTest, NAME##_##T) {                       \
-    auto orig_op = CreateOp("Conv2DBackpropInput");              \
-    orig_op->MutableAttrs()->Set("T", T);                        \
-    orig_op->MutableAttrs()->Set("padding", "VALID");            \
-    CheckRewrite(orig_op.get(), "_MklEagerConv2DBackpropInput"); \
+#define REGISTER_TEST(NAME, T, INPUT)                             \
+  TEST_F(EagerOpRewriteTest, NAME##_##T) {                        \
+    auto orig_op = CreateOp("Conv2DBackpropInput");               \
+    orig_op->MutableAttrs()->Set("T", T);                         \
+    orig_op->MutableAttrs()->Set("padding", "VALID");             \
+    CheckRewrite(orig_op.get(), "_MklNativeConv2DBackpropInput"); \
   }
 REGISTER_TEST_ALL_TYPES(Conv2DBackpropInput);
 #undef REGISTER_TEST
 
-#define REGISTER_TEST(NAME, T, INPUT)                             \
-  TEST_F(EagerOpRewriteTest, NAME##_##T) {                        \
-    auto orig_op = CreateOp("Conv2DBackpropFilter");              \
-    orig_op->MutableAttrs()->Set("T", T);                         \
-    orig_op->MutableAttrs()->Set("padding", "VALID");             \
-    CheckRewrite(orig_op.get(), "_MklEagerConv2DBackpropFilter"); \
+#define REGISTER_TEST(NAME, T, INPUT)                              \
+  TEST_F(EagerOpRewriteTest, NAME##_##T) {                         \
+    auto orig_op = CreateOp("Conv2DBackpropFilter");               \
+    orig_op->MutableAttrs()->Set("T", T);                          \
+    orig_op->MutableAttrs()->Set("padding", "VALID");              \
+    CheckRewrite(orig_op.get(), "_MklNativeConv2DBackpropFilter"); \
   }
 REGISTER_TEST_ALL_TYPES(Conv2DBackpropFilter);
 #undef REGISTER_TEST
diff --git a/tensorflow/core/common_runtime/eager/placement_test.cc b/tensorflow/core/common_runtime/eager/placement_test.cc
index 4ea38d2..8e923e6 100644
--- a/tensorflow/core/common_runtime/eager/placement_test.cc
+++ b/tensorflow/core/common_runtime/eager/placement_test.cc
@@ -85,7 +85,6 @@
     InitDeviceManager();
     context_ = new EagerContext(
         opts, policy,
-        /* default_mirroring_policy */ MIRRORING_NONE,
         /* async */ false,
         /* lazy_copy_function_remote_inputs */ false, device_manager_,
         /* device_mgr_owned */ false, /* rendezvous */ nullptr,
diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc
index 148c6c6..619715f 100644
--- a/tensorflow/core/common_runtime/eager/placement_utils.cc
+++ b/tensorflow/core/common_runtime/eager/placement_utils.cc
@@ -185,34 +185,35 @@
   if (VariantDeviceIsCustom(op.Device())) {
     *device = op.Device();
     return Status::OK();
+  } else if (!op.DeviceName().empty()) {
+    // Don't override explicit placements.
+    return Status::OK();
   }
 
+  // Ops are placed on a custom device if there is no explicitly requested
+  // placement and exactly one custom device appears among the op inputs.
   if (!op.Inputs().empty()) {
-    // We keep track of what we've seen with devices instead of booleans to be
-    // able to provide a meaningful error message below.
-    VariantDevice first = op.Inputs()[0]->device();
-    VariantDevice different = first;  // A different input device, if any.
-    VariantDevice custom = first;     // The first custom device seen, or an
-                                      // arbitrary non-custom device otherwise.
-    for (size_t i = 1; first == different && i < op.Inputs().size(); ++i) {
-      VariantDevice device = op.Inputs()[i]->device();
-      if (device != first) {
-        different = device;
-      }
-      if (!VariantDeviceIsCustom(custom) && VariantDeviceIsCustom(device)) {
-        custom = device;
-      }
-      if (different != first && VariantDeviceIsCustom(custom)) {
-        return errors::InvalidArgument(absl::StrCat(
-            "If an operation has one of its inputs in a custom device, then "
-            "all inputs should be on that same device. Operation ",
-            op.Name(), " has one input in custom device ",
-            VariantDeviceName(custom),
-            " and at least one input in a different device ",
-            VariantDeviceName(custom == first ? different : first)));
+    CustomDevice* first = nullptr;
+    for (const TensorHandle* input : op.Inputs()) {
+      if (VariantDeviceIsCustom(input->device())) {
+        CustomDevice* current = absl::get<CustomDevice*>(input->device());
+        if (first == nullptr) {
+          first = current;
+        } else if (first != current) {
+          return errors::InvalidArgument(absl::StrCat(
+              "If an operation has one of its inputs in a custom device, then "
+              "all inputs should be on that same custom device or another "
+              "physical device. Operation ",
+              op.Name(),
+              " has one input in custom "
+              "device ",
+              VariantDeviceName(first),
+              " and at least one input in a different custom device ",
+              VariantDeviceName(current)));
+        }
       }
     }
-    if (different == first && VariantDeviceIsCustom(custom)) {
+    if (first != nullptr) {
       *device = first;
       return Status::OK();
     }
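
The rewritten loop above reduces placement to a "unique custom device among the inputs" scan. Below is a minimal, self-contained sketch of that pattern with hypothetical stand-in types (CustomDevice and UniqueCustomDevice are illustrative names here, not the real eager runtime API):

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for the runtime's custom device handle.
struct CustomDevice {
  std::string name;
};

// Returns the single custom device shared by all custom-device inputs, or
// nullptr if no input lives on a custom device. Signals an error if two
// different custom devices appear, mirroring the InvalidArgument case above.
CustomDevice* UniqueCustomDevice(const std::vector<CustomDevice*>& input_devices) {
  CustomDevice* first = nullptr;
  for (CustomDevice* current : input_devices) {
    if (current == nullptr) continue;  // Input on a physical device; ignore.
    if (first == nullptr) {
      first = current;
    } else if (first != current) {
      throw std::invalid_argument("inputs span two custom devices: " +
                                  first->name + " vs " + current->name);
    }
  }
  return first;
}

int main() {
  CustomDevice custom{"my_custom_device"};
  // Two inputs on the same custom device, two on physical devices (nullptr).
  std::vector<CustomDevice*> devices = {nullptr, &custom, nullptr, &custom};
  return UniqueCustomDevice(devices) == &custom ? 0 : 1;
}
```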
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
index 6b3c464..7f707ee 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc
@@ -38,9 +38,8 @@
       "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
   auto ctx = new EagerContext(
       SessionOptions(),
-      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
-      &device_mgr, false, nullptr, nullptr, nullptr);
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
+      false, &device_mgr, false, nullptr, nullptr, nullptr);
   TensorHandle* sync_th =
       TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr, ctx);
   TensorHandle* async_th = TensorHandle::CreateEmptyLocalHandle(
@@ -106,7 +105,7 @@
     context_ = new EagerContext(
         SessionOptions(),
         tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
+        /* async= */ false,
         /* lazy_copy_function_remote_inputs= */ false, device_mgr_,
         /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
         /* custom_kernel_creator= */ nullptr,
@@ -257,9 +256,8 @@
   StaticDeviceMgr local_device_mgr(std::move(d0));
   auto ctx = new EagerContext(
       SessionOptions(),
-      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
-      &local_device_mgr, false, nullptr, nullptr, nullptr);
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
+      false, &local_device_mgr, false, nullptr, nullptr, nullptr);
 
   tensorflow::DataType dtype = DT_RESOURCE;
   TensorShape shape = {2};
@@ -290,9 +288,8 @@
   StaticDeviceMgr local_device_mgr(std::move(d_local));
   auto ctx = new EagerContext(
       SessionOptions(),
-      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
-      &local_device_mgr, false, nullptr, nullptr, nullptr);
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
+      false, &local_device_mgr, false, nullptr, nullptr, nullptr);
 
   std::unique_ptr<Device> d0(
       CreateDevice("CPU", "/job:worker/task:0/device:CPU:0", false));
@@ -346,7 +343,7 @@
     context_ = new EagerContext(
         SessionOptions(),
         tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
+        /* async= */ false,
         /* lazy_copy_function_remote_inputs= */ false, device_mgr_,
         /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
         /* custom_kernel_creator= */ nullptr,
@@ -387,7 +384,7 @@
   EagerContext* context = new EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
+      /* async= */ false,
       /* lazy_copy_function_remote_inputs= */ false, &device_mgr,
       /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
       /* custom_kernel_creator= */ nullptr,
diff --git a/tensorflow/core/common_runtime/permuter.cc b/tensorflow/core/common_runtime/permuter.cc
index 45caca0..61b8dcb 100644
--- a/tensorflow/core/common_runtime/permuter.cc
+++ b/tensorflow/core/common_runtime/permuter.cc
@@ -39,17 +39,14 @@
 Permuter::Permuter()
     : col_ctx_(nullptr), col_params_(nullptr), done_(nullptr), counter_(0) {}
 
-bool Permuter::CheckCounter() {
-  mutex_lock lock(mu_counter_);
-  ++counter_;
-  if (counter_ == 2) return true;
-  return false;
-}
-
-StatusCallback Permuter::HalfDone() {
+StatusCallback Permuter::CheckCounterAndCallDone() {
   return [this](const Status& s) {
+    mu_.lock();
     status_.Update(s);
-    if (CheckCounter()) done_(status_);
+    int counter = ++counter_;
+    Status status = status_;
+    mu_.unlock();
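+    // Invoke done_ outside the lock, with the counter and status values
+    // captured while the lock was held.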
+    if (counter == 2) done_(status);
   };
 }
 
@@ -71,11 +68,11 @@
   done_ = std::move(done);
   DispatchSend(col_params_->default_rank,
                col_params_->instance.permutation[col_params_->default_rank],
-               col_ctx_->input, HalfDone());
+               col_ctx_->input, CheckCounterAndCallDone());
   for (int i = 0; i < col_params_->instance.permutation.size(); ++i) {
     if (col_params_->default_rank == col_params_->instance.permutation[i]) {
       DispatchRecv(i, col_params_->instance.permutation[i], col_ctx_->output,
-                   HalfDone());
+                   CheckCounterAndCallDone());
     }
   }
 }
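
Note: the refactoring above folds the old two-step CheckCounter()/HalfDone()
pair into a single callback so that status_ is only ever read while mu_ is
held, and done_ runs after the lock is released. A minimal standalone sketch
of the same "fire done exactly once after two completions" pattern, with
illustrative names that are not part of the TensorFlow API:

// DoneOnce is an assumed, simplified stand-in for Permuter's bookkeeping.
#include <functional>
#include <mutex>
#include <utility>

class DoneOnce {
 public:
  explicit DoneOnce(std::function<void(int)> done) : done_(std::move(done)) {}

  // Returns a callback to hand to each of the two sub-operations
  // (one send and one receive, in Permuter's case).
  std::function<void(int)> MakeCallback() {
    return [this](int status) {
      mu_.lock();
      if (status != 0) status_ = status;  // Record any failure.
      int counter = ++counter_;
      int combined = status_;
      mu_.unlock();
      // Call done_ outside the lock, using the values captured above, so
      // done_ may safely delete this object or re-enter the collective.
      if (counter == 2) done_(combined);
    };
  }

 private:
  std::function<void(int)> done_;
  std::mutex mu_;
  int status_ = 0;   // 0 means OK; guarded by mu_.
  int counter_ = 0;  // Guarded by mu_.
};

Exactly one of the two invocations observes counter == 2, so done_ fires once
with a status that reflects both sub-operations.
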
diff --git a/tensorflow/core/common_runtime/permuter.h b/tensorflow/core/common_runtime/permuter.h
index 245168b..a99b848 100644
--- a/tensorflow/core/common_runtime/permuter.h
+++ b/tensorflow/core/common_runtime/permuter.h
@@ -67,9 +67,9 @@
   std::shared_ptr<CollectiveContext> col_ctx_;
   const CollectiveParams* col_params_;  // Not owned
   StatusCallback done_;
-  Status status_;
-  mutex mu_counter_;
-  int counter_ TF_GUARDED_BY(mu_counter_);
+  mutex mu_;
+  Status status_ TF_GUARDED_BY(mu_);
+  int counter_ TF_GUARDED_BY(mu_);
 
   void DispatchSend(int src_rank, int target_rank, const Tensor* tensor,
                     const StatusCallback& done);
@@ -77,12 +77,10 @@
   void DispatchRecv(int src_rank, int target_rank, Tensor* tensor,
                     const StatusCallback& done);
 
-  // Checks if counter_ reaches 2.
   // Atomically increments counter_ by one for sending, one for receiving.
-  // The purpose of this check is to ensure that done_ is called only once.
-  bool CheckCounter();
-
-  StatusCallback HalfDone();
+  // Invokes done_ when counter_ reaches 2.
+  // Checking counter_ ensures that done_ is called exactly once.
+  StatusCallback CheckCounterAndCallDone();
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index 35971e3..41a3815 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -50,6 +50,78 @@
 )
 
 cc_library(
+    name = "credentials_factory",
+    srcs = ["credentials_factory.cc"],
+    hdrs = ["credentials_factory.h"],
+    deps = [
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        tf_grpc_cc_dependency(),
+    ],
+)
+
+tf_cc_test(
+    name = "credentials_factory_test",
+    srcs = ["credentials_factory_test.cc"],
+    deps = [
+        ":credentials_factory",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
+cc_library(
+    name = "data_service",
+    srcs = ["data_service.cc"],
+    hdrs = [
+        "data_service.h",
+    ],
+    deps = [
+        ":credentials_factory",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
+        ":grpc_util",
+        ":worker_cc_grpc_proto",
+        ":worker_proto_cc",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        tf_grpc_cc_dependency(),
+    ],
+)
+
+tf_cc_test(
+    name = "data_service_test",
+    srcs = ["data_service_test.cc"],
+    tags = ["no_windows"],
+    deps = [
+        ":data_service",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
+        ":grpc_dispatcher_impl",
+        ":grpc_util",
+        ":grpc_worker_impl",
+        ":local_credentials_factory",
+        ":server_lib",
+        ":test_cluster",
+        ":test_util",
+        ":worker_cc_grpc_proto",
+        ":worker_proto_cc",
+        "@com_google_absl//absl/strings",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/data:compression_utils",
+        "//tensorflow/core/kernels/data:dataset_test_base",
+        tf_grpc_cc_dependency(),
+    ] + tf_protos_profiler_service(),
+)
+
+cc_library(
     name = "dataset_store",
     srcs = ["dataset_store.cc"],
     hdrs = ["dataset_store.h"],
@@ -78,6 +150,14 @@
     ],
 )
 
+cc_grpc_library(
+    name = "dispatcher_cc_grpc_proto",
+    srcs = [":dispatcher_proto"],
+    generate_mocks = True,
+    grpc_only = True,
+    deps = [":dispatcher_proto_cc"],
+)
+
 cc_library(
     name = "dispatcher_impl",
     srcs = ["dispatcher_impl.cc"],
@@ -141,32 +221,14 @@
 )
 
 cc_library(
-    name = "worker_impl",
-    srcs = ["worker_impl.cc"],
-    hdrs = [
-        "worker_impl.h",
-    ],
+    name = "grpc_dispatcher_impl",
+    srcs = ["grpc_dispatcher_impl.cc"],
+    hdrs = ["grpc_dispatcher_impl.h"],
     deps = [
-        ":common_proto_cc",
-        ":credentials_factory",
-        ":data_service",
         ":dispatcher_cc_grpc_proto",
-        ":dispatcher_proto_cc",
-        ":grpc_util",
-        ":utils",
-        ":worker_proto_cc",
-        "//tensorflow/c:c_api_internal",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/core:core_cpu",
-        "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework_internal",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
+        ":dispatcher_impl",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/data:dataset_proto_cc",
-        "//tensorflow/core/data:standalone",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/memory",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
         tf_grpc_cc_dependency(),
     ],
 )
@@ -197,6 +259,19 @@
 )
 
 cc_library(
+    name = "grpc_worker_impl",
+    srcs = ["grpc_worker_impl.cc"],
+    hdrs = ["grpc_worker_impl.h"],
+    deps = [
+        ":worker_cc_grpc_proto",
+        ":worker_impl",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+        tf_grpc_cc_dependency(),
+    ],
+)
+
+cc_library(
     name = "journal",
     srcs = ["journal.cc"],
     hdrs = ["journal.h"],
@@ -209,6 +284,15 @@
     ],
 )
 
+tf_proto_library(
+    name = "journal_proto",
+    srcs = ["journal.proto"],
+    cc_api_version = 2,
+    protodeps = [
+        ":common_proto",
+    ],
+)
+
 tf_cc_test(
     name = "journal_test",
     srcs = ["journal_test.cc"],
@@ -224,50 +308,51 @@
     ],
 )
 
-tf_proto_library(
-    name = "journal_proto",
-    srcs = ["journal.proto"],
-    cc_api_version = 2,
-    protodeps = [
-        ":common_proto",
-    ],
-)
-
-cc_library(
-    name = "credentials_factory",
-    srcs = ["credentials_factory.cc"],
-    hdrs = ["credentials_factory.h"],
-    deps = [
-        "//tensorflow/core:lib",
-        "@com_google_absl//absl/strings",
-        tf_grpc_cc_dependency(),
-    ],
-)
-
-tf_cc_test(
-    name = "credentials_factory_test",
-    srcs = ["credentials_factory_test.cc"],
-    deps = [
-        ":credentials_factory",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-    ],
-)
-
 # Link this target to enable LOCAL credentials for the dataset service.
 cc_library(
     name = "local_credentials_factory",
     srcs = ["local_credentials_factory.cc"],
     deps = [
         ":credentials_factory",
+        "@com_google_absl//absl/memory",
         tf_grpc_cc_dependency(),
     ],
     alwayslink = 1,
 )
 
 cc_library(
+    name = "server_lib",
+    srcs = ["server_lib.cc"],
+    hdrs = ["server_lib.h"],
+    linkstatic = True,
+    visibility = [
+        "//visibility:public",
+    ],
+    deps = [
+        ":credentials_factory",
+        ":grpc_dispatcher_impl",
+        ":grpc_util",
+        ":grpc_worker_impl",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core/profiler/rpc:profiler_service_impl",
+        tf_grpc_cc_dependency(),
+    ],
+    alwayslink = 1,
+)
+
+# This needs to be cc_header_only_library because tf_pybind_cc_library_wrapper
+# does not pull in the server_lib.h header.
+cc_header_only_library(
+    name = "server_lib_headers_lib",
+    features = ["-parse_headers"],
+    deps = [
+        ":server_lib",
+    ],
+)
+
+cc_library(
     name = "test_cluster",
     testonly = True,
     srcs = ["test_cluster.cc"],
@@ -311,64 +396,6 @@
 )
 
 cc_library(
-    name = "grpc_dispatcher_impl",
-    srcs = ["grpc_dispatcher_impl.cc"],
-    hdrs = ["grpc_dispatcher_impl.h"],
-    deps = [
-        ":dispatcher_cc_grpc_proto",
-        ":dispatcher_impl",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
-        tf_grpc_cc_dependency(),
-    ],
-)
-
-cc_library(
-    name = "grpc_worker_impl",
-    srcs = ["grpc_worker_impl.cc"],
-    hdrs = ["grpc_worker_impl.h"],
-    deps = [
-        ":worker_cc_grpc_proto",
-        ":worker_impl",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
-        tf_grpc_cc_dependency(),
-    ],
-)
-
-# This needs to be cc_header_only_library - tf_pybind_cc_library_wrapper
-# does not pull in the server_lib.h header.
-cc_header_only_library(
-    name = "server_lib_headers_lib",
-    features = ["-parse_headers"],
-    deps = [
-        ":server_lib",
-    ],
-)
-
-cc_library(
-    name = "server_lib",
-    srcs = ["server_lib.cc"],
-    hdrs = ["server_lib.h"],
-    linkstatic = True,
-    visibility = [
-        "//visibility:public",
-    ],
-    deps = [
-        ":credentials_factory",
-        ":grpc_dispatcher_impl",
-        ":grpc_util",
-        ":grpc_worker_impl",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:tensorflow",
-        "//tensorflow/core/profiler/rpc:profiler_service_impl",
-        tf_grpc_cc_dependency(),
-    ],
-    alwayslink = 1,
-)
-
-cc_library(
     name = "utils",
     srcs = ["utils.cc"],
     hdrs = ["utils.h"],
@@ -394,62 +421,6 @@
     ],
 )
 
-cc_library(
-    name = "data_service",
-    srcs = ["data_service.cc"],
-    hdrs = [
-        "data_service.h",
-    ],
-    deps = [
-        ":credentials_factory",
-        ":dispatcher_cc_grpc_proto",
-        ":dispatcher_proto_cc",
-        ":grpc_util",
-        ":worker_cc_grpc_proto",
-        ":worker_proto_cc",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:protos_all_cc",
-        tf_grpc_cc_dependency(),
-    ],
-)
-
-tf_cc_test(
-    name = "data_service_test",
-    srcs = ["data_service_test.cc"],
-    tags = ["no_windows"],
-    deps = [
-        ":data_service",
-        ":dispatcher_cc_grpc_proto",
-        ":dispatcher_proto_cc",
-        ":grpc_dispatcher_impl",
-        ":grpc_util",
-        ":grpc_worker_impl",
-        ":local_credentials_factory",
-        ":server_lib",
-        ":test_cluster",
-        ":test_util",
-        ":worker_cc_grpc_proto",
-        ":worker_proto_cc",
-        "@com_google_absl//absl/strings",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core/data:compression_utils",
-        "//tensorflow/core/kernels/data:dataset_test_base",
-        tf_grpc_cc_dependency(),
-    ] + tf_protos_profiler_service(),
-)
-
-cc_grpc_library(
-    name = "dispatcher_cc_grpc_proto",
-    srcs = [":dispatcher_proto"],
-    generate_mocks = True,
-    grpc_only = True,
-    deps = [":dispatcher_proto_cc"],
-)
-
 cc_grpc_library(
     name = "worker_cc_grpc_proto",
     srcs = [":worker_proto"],
@@ -457,3 +428,34 @@
     grpc_only = True,
     deps = [":worker_proto_cc"],
 )
+
+cc_library(
+    name = "worker_impl",
+    srcs = ["worker_impl.cc"],
+    hdrs = [
+        "worker_impl.h",
+    ],
+    deps = [
+        ":common_proto_cc",
+        ":credentials_factory",
+        ":data_service",
+        ":dispatcher_cc_grpc_proto",
+        ":dispatcher_proto_cc",
+        ":grpc_util",
+        ":utils",
+        ":worker_proto_cc",
+        "//tensorflow/c:c_api_internal",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/data:dataset_proto_cc",
+        "//tensorflow/core/data:standalone",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/memory",
+        tf_grpc_cc_dependency(),
+    ],
+)
diff --git a/tensorflow/core/data/service/credentials_factory.cc b/tensorflow/core/data/service/credentials_factory.cc
index 43b56d5..e7a5177 100644
--- a/tensorflow/core/data/service/credentials_factory.cc
+++ b/tensorflow/core/data/service/credentials_factory.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/data/service/credentials_factory.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/mutex.h"
 
@@ -35,9 +36,11 @@
 }
 }  // namespace
 
-void CredentialsFactory::Register(CredentialsFactory* factory) {
+void CredentialsFactory::Register(std::unique_ptr<CredentialsFactory> factory) {
   mutex_lock l(*get_lock());
-  if (!credentials_factories().insert({factory->Protocol(), factory}).second) {
+  std::string protocol = factory->Protocol();
+  if (!credentials_factories().insert({protocol, factory.release()}).second) {
     LOG(ERROR)
         << "Two credentials factories are being registered with protocol "
-        << factory->Protocol() << ". Which one gets used is undefined.";
+        << protocol << ". Which one gets used is undefined.";
@@ -45,11 +48,11 @@
 }
 
 Status CredentialsFactory::Get(absl::string_view protocol,
-                               CredentialsFactory** out) {
+                               CredentialsFactory*& out) {
   mutex_lock l(*get_lock());
   auto it = credentials_factories().find(std::string(protocol));
   if (it != credentials_factories().end()) {
-    *out = it->second;
+    out = it->second;
     return Status::OK();
   }
 
@@ -66,18 +69,18 @@
 
 Status CredentialsFactory::CreateServerCredentials(
     absl::string_view protocol,
-    std::shared_ptr<::grpc::ServerCredentials>* out) {
+    std::shared_ptr<::grpc::ServerCredentials>& out) {
   CredentialsFactory* factory;
-  TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory));
+  TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, factory));
   TF_RETURN_IF_ERROR(factory->CreateServerCredentials(out));
   return Status::OK();
 }
 
 Status CredentialsFactory::CreateClientCredentials(
     absl::string_view protocol,
-    std::shared_ptr<::grpc::ChannelCredentials>* out) {
+    std::shared_ptr<::grpc::ChannelCredentials>& out) {
   CredentialsFactory* factory;
-  TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory));
+  TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, factory));
   TF_RETURN_IF_ERROR(factory->CreateClientCredentials(out));
   return Status::OK();
 }
@@ -87,14 +90,14 @@
   std::string Protocol() override { return "grpc"; }
 
   Status CreateServerCredentials(
-      std::shared_ptr<::grpc::ServerCredentials>* out) override {
-    *out = ::grpc::InsecureServerCredentials();
+      std::shared_ptr<::grpc::ServerCredentials>& out) override {
+    out = ::grpc::InsecureServerCredentials();
     return Status::OK();
   }
 
   Status CreateClientCredentials(
-      std::shared_ptr<::grpc::ChannelCredentials>* out) override {
-    *out = ::grpc::InsecureChannelCredentials();
+      std::shared_ptr<::grpc::ChannelCredentials>& out) override {
+    out = ::grpc::InsecureChannelCredentials();
     return Status::OK();
   }
 };
@@ -102,8 +105,8 @@
 class InsecureCredentialsRegistrar {
  public:
   InsecureCredentialsRegistrar() {
-    auto factory = new InsecureCredentialsFactory();
-    CredentialsFactory::Register(factory);
+    CredentialsFactory::Register(
+        absl::make_unique<InsecureCredentialsFactory>());
   }
 };
 static InsecureCredentialsRegistrar registrar;
diff --git a/tensorflow/core/data/service/credentials_factory.h b/tensorflow/core/data/service/credentials_factory.h
index 2407f64..754ed54 100644
--- a/tensorflow/core/data/service/credentials_factory.h
+++ b/tensorflow/core/data/service/credentials_factory.h
@@ -34,33 +34,33 @@
   // look up with `Get` to find the registered credentials factory.
   virtual std::string Protocol() = 0;
 
-  // Stores server credentials to `*out`.
+  // Stores server credentials to `out`.
   virtual Status CreateServerCredentials(
-      std::shared_ptr<::grpc::ServerCredentials>* out) = 0;
+      std::shared_ptr<::grpc::ServerCredentials>& out) = 0;
 
-  // Stores client credentials to `*out`.
+  // Stores client credentials to `out`.
   virtual Status CreateClientCredentials(
-      std::shared_ptr<::grpc::ChannelCredentials>* out) = 0;
+      std::shared_ptr<::grpc::ChannelCredentials>& out) = 0;
 
   // Registers a credentials factory.
-  static void Register(CredentialsFactory* factory);
+  static void Register(std::unique_ptr<CredentialsFactory> factory);
 
   // Creates server credentials using the credentials factory registered as
-  // `protocol`, and stores them to `*out`.
+  // `protocol`, and stores them to `out`.
   static Status CreateServerCredentials(
       absl::string_view protocol,
-      std::shared_ptr<::grpc::ServerCredentials>* out);
+      std::shared_ptr<::grpc::ServerCredentials>& out);
 
   // Creates client credentials using the credentials factory registered as
-  // `protocol`, and stores them to `*out`.
+  // `protocol`, and stores them to `out`.
   static Status CreateClientCredentials(
       absl::string_view protocol,
-      std::shared_ptr<::grpc::ChannelCredentials>* out);
+      std::shared_ptr<::grpc::ChannelCredentials>& out);
 
  private:
-  // Gets the credentials factory registered via `Register` for the specified
-  // protocol, and stores it to `*out`.
-  static Status Get(const absl::string_view protocol, CredentialsFactory** out);
+  // Borrows a pointer to the credentials factory registered via `Register`
+  // for the specified protocol, and stores it to `out`.
+  static Status Get(const absl::string_view protocol, CredentialsFactory*& out);
 };
 
 }  // namespace data
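
Note: with the reference-based interface above, a custom credentials factory
follows the same shape as InsecureCredentialsFactory. A hedged sketch; the
"my_tls" protocol name, the class names, and the TLS option wiring are
illustrative assumptions, not part of the codebase:

#include "absl/memory/memory.h"
#include "grpcpp/security/credentials.h"
#include "grpcpp/security/server_credentials.h"
#include "tensorflow/core/data/service/credentials_factory.h"

namespace tensorflow {
namespace data {

class MyTlsCredentialsFactory : public CredentialsFactory {
 public:
  std::string Protocol() override { return "my_tls"; }

  Status CreateServerCredentials(
      std::shared_ptr<::grpc::ServerCredentials>& out) override {
    ::grpc::SslServerCredentialsOptions options;
    // ... populate options.pem_key_cert_pairs from deployment config ...
    out = ::grpc::SslServerCredentials(options);
    return Status::OK();
  }

  Status CreateClientCredentials(
      std::shared_ptr<::grpc::ChannelCredentials>& out) override {
    out = ::grpc::SslCredentials(::grpc::SslCredentialsOptions());
    return Status::OK();
  }
};

// Static registrar, mirroring InsecureCredentialsRegistrar: linking this
// translation unit makes the "my_tls" protocol available through
// CredentialsFactory::CreateServerCredentials / CreateClientCredentials.
class MyTlsCredentialsRegistrar {
 public:
  MyTlsCredentialsRegistrar() {
    CredentialsFactory::Register(absl::make_unique<MyTlsCredentialsFactory>());
  }
};
static MyTlsCredentialsRegistrar my_tls_registrar;

}  // namespace data
}  // namespace tensorflow
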
diff --git a/tensorflow/core/data/service/credentials_factory_test.cc b/tensorflow/core/data/service/credentials_factory_test.cc
index 507c553..fbaed58 100644
--- a/tensorflow/core/data/service/credentials_factory_test.cc
+++ b/tensorflow/core/data/service/credentials_factory_test.cc
@@ -32,43 +32,44 @@
   std::string Protocol() override { return "test"; }
 
   Status CreateServerCredentials(
-      std::shared_ptr<grpc::ServerCredentials>* out) override {
+      std::shared_ptr<grpc::ServerCredentials>& out) override {
     return errors::Internal(kFailedToCreateServerCredentials);
   }
 
   Status CreateClientCredentials(
-      std::shared_ptr<grpc::ChannelCredentials>* out) override {
+      std::shared_ptr<grpc::ChannelCredentials>& out) override {
     return errors::Internal(kFailedToCreateClientCredentials);
   }
 };
 }  // namespace
 
 TEST(CredentialsFactory, Register) {
-  TestCredentialsFactory test_factory;
-  CredentialsFactory::Register(&test_factory);
+  auto test_factory = absl::make_unique<TestCredentialsFactory>();
+  std::string protocol = test_factory->Protocol();
+  CredentialsFactory::Register(std::move(test_factory));
   std::shared_ptr<grpc::ServerCredentials> server_credentials;
   ASSERT_EQ(errors::Internal(kFailedToCreateServerCredentials),
-            CredentialsFactory::CreateServerCredentials(test_factory.Protocol(),
-                                                        &server_credentials));
+            CredentialsFactory::CreateServerCredentials(protocol,
+                                                        server_credentials));
   std::shared_ptr<grpc::ChannelCredentials> client_credentials;
   ASSERT_EQ(errors::Internal(kFailedToCreateClientCredentials),
-            CredentialsFactory::CreateClientCredentials(test_factory.Protocol(),
-                                                        &client_credentials));
+            CredentialsFactory::CreateClientCredentials(protocol,
+                                                        client_credentials));
 }
 
 TEST(CredentialsFactory, DefaultGrpcProtocol) {
   std::shared_ptr<grpc::ServerCredentials> server_credentials;
   TF_ASSERT_OK(
-      CredentialsFactory::CreateServerCredentials("grpc", &server_credentials));
+      CredentialsFactory::CreateServerCredentials("grpc", server_credentials));
   std::shared_ptr<grpc::ChannelCredentials> client_credentials;
   TF_ASSERT_OK(
-      CredentialsFactory::CreateClientCredentials("grpc", &client_credentials));
+      CredentialsFactory::CreateClientCredentials("grpc", client_credentials));
 }
 
 TEST(CredentialsFactory, MissingServerProtocol) {
   std::shared_ptr<grpc::ServerCredentials> server_credentials;
   Status s = CredentialsFactory::CreateServerCredentials("unknown_protocol",
-                                                         &server_credentials);
+                                                         server_credentials);
   ASSERT_EQ(error::Code::NOT_FOUND, s.code());
   ASSERT_TRUE(
       absl::StrContains(s.ToString(),
@@ -79,7 +80,7 @@
 TEST(CredentialsFactory, MissingClientProtocol) {
   std::shared_ptr<grpc::ChannelCredentials> client_credentials;
   Status s = CredentialsFactory::CreateClientCredentials("unknown_protocol",
-                                                         &client_credentials);
+                                                         client_credentials);
   ASSERT_EQ(error::Code::NOT_FOUND, s.code());
   ASSERT_TRUE(
       absl::StrContains(s.ToString(),
diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc
index 0f25805..c64daa6 100644
--- a/tensorflow/core/data/service/data_service.cc
+++ b/tensorflow/core/data/service/data_service.cc
@@ -31,11 +31,11 @@
 constexpr const char kOneEpoch[] = "one_epoch";
 }  // namespace
 
-Status ParseProcessingMode(const std::string& s, ProcessingMode* mode) {
+Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) {
   if (s == kParallelEpochs) {
-    *mode = ProcessingMode::PARALLEL_EPOCHS;
+    mode = ProcessingMode::PARALLEL_EPOCHS;
   } else if (s == kOneEpoch) {
-    *mode = ProcessingMode::ONE_EPOCH;
+    mode = ProcessingMode::ONE_EPOCH;
   } else {
     return errors::InvalidArgument("Unrecognized processing mode: ", s);
   }
@@ -105,7 +105,7 @@
 }
 
 Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
-                                                    int64* dataset_id) {
+                                                    int64& dataset_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetOrRegisterDatasetRequest req;
   *req.mutable_dataset()->mutable_graph() = dataset;
@@ -115,13 +115,13 @@
   if (!status.ok()) {
     return grpc_util::WrapError("Failed to register dataset", status);
   }
-  *dataset_id = resp.dataset_id();
+  dataset_id = resp.dataset_id();
   return Status::OK();
 }
 
 Status DataServiceDispatcherClient::CreateJob(int64 dataset_id,
                                               ProcessingMode processing_mode,
-                                              int64* job_client_id) {
+                                              int64& job_client_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   CreateJobRequest req;
   req.set_dataset_id(dataset_id);
@@ -134,13 +134,13 @@
         absl::StrCat("Failed to create job for dataset with id ", dataset_id),
         status);
   }
-  *job_client_id = resp.job_client_id();
+  job_client_id = resp.job_client_id();
   return Status::OK();
 }
 
 Status DataServiceDispatcherClient::GetOrCreateJob(
     int64 dataset_id, ProcessingMode processing_mode,
-    const std::string& job_name, int job_name_index, int64* job_client_id) {
+    const std::string& job_name, int job_name_index, int64& job_client_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetOrCreateJobRequest req;
   req.set_dataset_id(dataset_id);
@@ -156,7 +156,7 @@
                      dataset_id),
         status);
   }
-  *job_client_id = resp.job_client_id();
+  job_client_id = resp.job_client_id();
   return Status::OK();
 }
 
@@ -176,8 +176,8 @@
 }
 
 Status DataServiceDispatcherClient::GetTasks(int64 job_client_id,
-                                             std::vector<TaskInfo>* tasks,
-                                             bool* job_finished) {
+                                             std::vector<TaskInfo>& tasks,
+                                             bool& job_finished) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetTasksRequest req;
   req.set_job_client_id(job_client_id);
@@ -187,16 +187,16 @@
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to get tasks", s);
   }
-  tasks->clear();
+  tasks.clear();
   for (auto& task : resp.task_info()) {
-    tasks->push_back(task);
+    tasks.push_back(task);
   }
-  *job_finished = resp.job_finished();
+  job_finished = resp.job_finished();
   return Status::OK();
 }
 
 Status DataServiceDispatcherClient::GetWorkers(
-    std::vector<WorkerInfo>* workers) {
+    std::vector<WorkerInfo>& workers) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetWorkersRequest req;
   GetWorkersResponse resp;
@@ -205,9 +205,9 @@
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to get workers", s);
   }
-  workers->clear();
+  workers.clear();
   for (auto& worker : resp.workers()) {
-    workers->push_back(worker);
+    workers.push_back(worker);
   }
   return Status::OK();
 }
@@ -219,15 +219,15 @@
   }
   std::shared_ptr<grpc::ChannelCredentials> credentials;
   TF_RETURN_IF_ERROR(
-      CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
+      CredentialsFactory::CreateClientCredentials(protocol_, credentials));
   auto channel = grpc::CreateChannel(address_, credentials);
   stub_ = DispatcherService::NewStub(channel);
   return Status::OK();
 }
 
 Status DataServiceWorkerClient::GetElement(int64 task_id,
-                                           CompressedElement* element,
-                                           bool* end_of_sequence) {
+                                           CompressedElement& element,
+                                           bool& end_of_sequence) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   GetElementRequest req;
   req.set_task_id(task_id);
@@ -237,9 +237,9 @@
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to get element", s);
   }
-  *end_of_sequence = resp.end_of_sequence();
-  if (!*end_of_sequence) {
-    *element = std::move(*resp.mutable_compressed_element());
+  end_of_sequence = resp.end_of_sequence();
+  if (!end_of_sequence) {
+    element = std::move(*resp.mutable_compressed_element());
   }
   return Status::OK();
 }
@@ -251,7 +251,7 @@
   }
   std::shared_ptr<grpc::ChannelCredentials> credentials;
   TF_RETURN_IF_ERROR(
-      CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
+      CredentialsFactory::CreateClientCredentials(protocol_, credentials));
   grpc::ChannelArguments args;
   args.SetMaxReceiveMessageSize(-1);
   auto channel = grpc::CreateCustomChannel(address_, credentials, args);
@@ -261,20 +261,20 @@
 
 Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceDispatcherClient>* out) {
+    std::unique_ptr<DataServiceDispatcherClient>& out) {
   auto client =
       absl::make_unique<DataServiceDispatcherClient>(address, protocol);
   TF_RETURN_IF_ERROR(client->Initialize());
-  *out = std::move(client);
+  out = std::move(client);
   return Status::OK();
 }
 
 Status CreateDataServiceWorkerClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceWorkerClient>* out) {
+    std::unique_ptr<DataServiceWorkerClient>& out) {
   auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol);
   TF_RETURN_IF_ERROR(client->Initialize());
-  *out = std::move(client);
+  out = std::move(client);
   return Status::OK();
 }
 }  // namespace data
diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h
index 621e76d..c5eb6a3 100644
--- a/tensorflow/core/data/service/data_service.h
+++ b/tensorflow/core/data/service/data_service.h
@@ -34,8 +34,8 @@
 };
 
 // Parses a string representing a processing mode and stores the result in
-// *mode. Returns an InvalidArgument status if the string is not recognized.
-Status ParseProcessingMode(const std::string& s, ProcessingMode* mode);
+// `mode`. Returns an InvalidArgument status if the string is not recognized.
+Status ParseProcessingMode(const std::string& s, ProcessingMode& mode);
 
 // Converts a processing mode to its corresponding string.
 std::string ProcessingModeToString(ProcessingMode mode);
@@ -87,34 +87,34 @@
   Status GetDatasetDef(int64 dataset_id, DatasetDef& dataset_def);
 
   // Registers a dataset with the tf.data service, and stores the generated
-  // dataset id in `*dataset_id`.
-  Status RegisterDataset(GraphDef dataset, int64* dataset_id);
+  // dataset id in `dataset_id`.
+  Status RegisterDataset(GraphDef dataset, int64& dataset_id);
 
   // Creates a new tf.data service job for the specified dataset. The id for the
-  // created job will be stored in `*job_client_id`.
+  // created job will be stored in `job_client_id`.
   Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
-                   int64* job_client_id);
+                   int64& job_client_id);
 
   // Gets the job id for the job represented by the tuple
-  // (job_name, job_name_index), and stores the id in *job_client_id. If the
+  // (job_name, job_name_index), and stores the id in `job_client_id`. If the
   // job doesn't exist yet, it will be created.
   Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
                         const std::string& job_name, int job_name_index,
-                        int64* job_client_id);
+                        int64& job_client_id);
 
   // Releases a job client id, indicating that the id will no longer be used to
   // read from the job.
   Status ReleaseJobClient(int64 job_client_id);
 
   // Queries the dispatcher for the tasks associated with the specified job.
-  // The tasks will be stored in *tasks, and whether the job is finished will
-  // be stored in `*job_finished`.
-  Status GetTasks(int64 job_client_id, std::vector<TaskInfo>* tasks,
-                  bool* job_finished);
+  // The tasks will be stored in `tasks`, and whether the job is finished will
+  // be stored in `job_finished`.
+  Status GetTasks(int64 job_client_id, std::vector<TaskInfo>& tasks,
+                  bool& job_finished);
 
   // Queries the dispatcher for its registered workers. The worker info will be
-  // stored in `*workers`.
-  Status GetWorkers(std::vector<WorkerInfo>* workers);
+  // stored in `workers`.
+  Status GetWorkers(std::vector<WorkerInfo>& workers);
 
  protected:
   Status EnsureInitialized() override;
@@ -134,10 +134,10 @@
       : DataServiceClientBase(address, protocol) {}
 
   // Fetches the next element for the specified task_id. The element's
-  // compressed tensors will be stored in *element. If no element is available,
-  // `*end_of_sequence` will be `true`, and `element` will be left unchanged.
-  Status GetElement(int64 task_id, CompressedElement* element,
-                    bool* end_of_sequence);
+  // compressed tensors will be stored in `element`. If no element is available,
+  // `end_of_sequence` will be `true`, and `element` will be left unchanged.
+  Status GetElement(int64 task_id, CompressedElement& element,
+                    bool& end_of_sequence);
 
  protected:
   Status EnsureInitialized() override;
@@ -152,12 +152,12 @@
 // Creates and initializes a new tf.data service dispatcher client.
 Status CreateDataServiceDispatcherClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceDispatcherClient>* out);
+    std::unique_ptr<DataServiceDispatcherClient>& out);
 
 // Creates and initializes a new tf.data service worker client.
 Status CreateDataServiceWorkerClient(
     const std::string& address, const std::string& protocol,
-    std::unique_ptr<DataServiceWorkerClient>* out);
+    std::unique_ptr<DataServiceWorkerClient>& out);
 
 }  // namespace data
 }  // namespace tensorflow
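
Note: a short usage sketch of the dispatcher client with the new
reference-style output parameters. The address and the error handling are
placeholders, not taken from the codebase:

// Assumes `dataset_graph` already holds a serialized tf.data dataset.
Status RunJobExample(const GraphDef& dataset_graph) {
  DataServiceDispatcherClient dispatcher(/*address=*/"localhost:5000",
                                         /*protocol=*/"grpc");

  int64 dataset_id;
  TF_RETURN_IF_ERROR(dispatcher.RegisterDataset(dataset_graph, dataset_id));

  int64 job_client_id;
  TF_RETURN_IF_ERROR(dispatcher.CreateJob(
      dataset_id, ProcessingMode::PARALLEL_EPOCHS, job_client_id));

  std::vector<TaskInfo> tasks;
  bool job_finished = false;
  TF_RETURN_IF_ERROR(dispatcher.GetTasks(job_client_id, tasks, job_finished));
  // ... fetch elements from the workers listed in `tasks` via
  // DataServiceWorkerClient::GetElement ...
  return Status::OK();
}

One benefit of this kind of pointer-to-reference migration is that the
out-parameters can no longer be passed as null.
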
diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc
index 6075700..7b7e240 100644
--- a/tensorflow/core/data/service/data_service_test.cc
+++ b/tensorflow/core/data/service/data_service_test.cc
@@ -41,19 +41,19 @@
 
 TEST(DataService, ParseParallelEpochsProcessingMode) {
   ProcessingMode mode;
-  TF_ASSERT_OK(ParseProcessingMode("parallel_epochs", &mode));
+  TF_ASSERT_OK(ParseProcessingMode("parallel_epochs", mode));
   EXPECT_EQ(mode, ProcessingMode::PARALLEL_EPOCHS);
 }
 
 TEST(DataService, ParseOneEpochProcessingMode) {
   ProcessingMode mode;
-  TF_ASSERT_OK(ParseProcessingMode("one_epoch", &mode));
+  TF_ASSERT_OK(ParseProcessingMode("one_epoch", mode));
   EXPECT_EQ(mode, ProcessingMode::ONE_EPOCH);
 }
 
 TEST(DataService, ParseInvalidProcessingMode) {
   ProcessingMode mode;
-  Status s = ParseProcessingMode("invalid", &mode);
+  Status s = ParseProcessingMode("invalid", mode);
   EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
 }
 
@@ -69,7 +69,7 @@
   DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(),
                                          kProtocol);
   std::vector<WorkerInfo> workers;
-  TF_EXPECT_OK(dispatcher.GetWorkers(&workers));
+  TF_EXPECT_OK(dispatcher.GetWorkers(workers));
   EXPECT_EQ(1, workers.size());
 }
 
diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc
index de5f63a..dcd5cb5 100644
--- a/tensorflow/core/data/service/dispatcher_impl.cc
+++ b/tensorflow/core/data/service/dispatcher_impl.cc
@@ -71,14 +71,14 @@
 }
 
 Status CreateWorkerStub(const std::string& address, const std::string& protocol,
-                        std::unique_ptr<WorkerService::Stub>* stub) {
+                        std::unique_ptr<WorkerService::Stub>& stub) {
   ::grpc::ChannelArguments args;
   args.SetMaxReceiveMessageSize(-1);
   std::shared_ptr<::grpc::ChannelCredentials> credentials;
   TF_RETURN_IF_ERROR(
-      CredentialsFactory::CreateClientCredentials(protocol, &credentials));
+      CredentialsFactory::CreateClientCredentials(protocol, credentials));
   auto channel = ::grpc::CreateCustomChannel(address, credentials, args);
-  *stub = WorkerService::NewStub(channel);
+  stub = WorkerService::NewStub(channel);
   return Status::OK();
 }
 }  // namespace
@@ -117,7 +117,7 @@
   Update update;
   bool end_of_journal = false;
   FileJournalReader reader(Env::Default(), JournalDir(config_.work_dir()));
-  Status s = reader.Read(&update, &end_of_journal);
+  Status s = reader.Read(update, end_of_journal);
   if (errors::IsNotFound(s)) {
     LOG(INFO) << "No journal found. Starting dispatcher from new state.";
   } else if (!s.ok()) {
@@ -125,7 +125,7 @@
   } else {
     while (!end_of_journal) {
       TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
-      TF_RETURN_IF_ERROR(reader.Read(&update, &end_of_journal));
+      TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
     }
   }
   // Initialize the journal writer in `Start` so that we fail fast in case it
@@ -168,11 +168,11 @@
     if (it != tasks_by_job.end()) {
       task = it->second;
     } else {
-      TF_RETURN_IF_ERROR(CreateTask(job, worker_address, &task));
+      TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
     }
     TaskDef* task_def = response->add_tasks();
     std::shared_ptr<const Dataset> dataset;
-    TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, &dataset));
+    TF_RETURN_IF_ERROR(state_.DatasetFromId(job->dataset_id, dataset));
     std::string dataset_key =
         DatasetKey(dataset->dataset_id, dataset->fingerprint);
     if (config_.work_dir().empty()) {
@@ -199,7 +199,7 @@
   for (auto& update : request->updates()) {
     int64 task_id = update.task_id();
     std::shared_ptr<const Task> task;
-    TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, &task));
+    TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
     if (update.completed()) {
       if (task->finished) {
         VLOG(1) << "Received completion update for already-finished task "
@@ -220,7 +220,7 @@
     const GetDatasetDefRequest* request, GetDatasetDefResponse* response) {
   mutex_lock l(mu_);
   std::shared_ptr<const Dataset> dataset;
-  TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), &dataset));
+  TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), dataset));
   std::string key = DatasetKey(dataset->dataset_id, dataset->fingerprint);
   std::shared_ptr<const DatasetDef> dataset_def;
   TF_RETURN_IF_ERROR(dataset_store_->Get(key, dataset_def));
@@ -242,7 +242,7 @@
   VLOG(4) << "Registering dataset graph: " << graph.DebugString();
 #endif
   std::shared_ptr<const Dataset> dataset;
-  Status s = state_.DatasetFromFingerprint(fingerprint, &dataset);
+  Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
   if (s.ok()) {
     int64 id = dataset->dataset_id;
     VLOG(3) << "Received duplicate RegisterDataset request with fingerprint "
@@ -254,7 +254,7 @@
   }
 
   int64 id;
-  TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), &id));
+  TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), id));
   response->set_dataset_id(id);
   VLOG(3) << "Registered new dataset with id " << id;
   return Status::OK();
@@ -262,15 +262,15 @@
 
 Status DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
                                                   const DatasetDef& dataset,
-                                                  int64* dataset_id)
+                                                  int64& dataset_id)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-  *dataset_id = state_.NextAvailableDatasetId();
+  dataset_id = state_.NextAvailableDatasetId();
   Update update;
   RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
-  register_dataset->set_dataset_id(*dataset_id);
+  register_dataset->set_dataset_id(dataset_id);
   register_dataset->set_fingerprint(fingerprint);
   TF_RETURN_IF_ERROR(
-      dataset_store_->Put(DatasetKey(*dataset_id, fingerprint), dataset));
+      dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset));
   return Apply(update);
 }
 
@@ -284,11 +284,11 @@
   {
     mutex_lock l(mu_);
     TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode,
-                                 absl::optional<NamedJobKey>(), &job));
+                                 absl::optional<NamedJobKey>(), job));
     int64 job_client_id;
     TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
     response->set_job_client_id(job_client_id);
-    TF_RETURN_IF_ERROR(CreateTasksForJob(job, &tasks));
+    TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks));
   }
   TF_RETURN_IF_ERROR(AssignTasks(tasks));
 
@@ -309,7 +309,7 @@
   std::vector<std::shared_ptr<const Task>> tasks;
   {
     mutex_lock l(mu_);
-    Status s = state_.NamedJobByKey(key, &job);
+    Status s = state_.NamedJobByKey(key, job);
     if (s.ok()) {
       TF_RETURN_IF_ERROR(ValidateMatchingJob(job, requested_processing_mode,
                                              request->dataset_id()));
@@ -323,11 +323,11 @@
       return s;
     }
     TF_RETURN_IF_ERROR(
-        CreateJob(request->dataset_id(), requested_processing_mode, key, &job));
+        CreateJob(request->dataset_id(), requested_processing_mode, key, job));
     int64 job_client_id;
     TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
     response->set_job_client_id(job_client_id);
-    TF_RETURN_IF_ERROR(CreateTasksForJob(job, &tasks));
+    TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks));
   }
   TF_RETURN_IF_ERROR(AssignTasks(tasks));
   VLOG(3) << "Created job " << job->job_id << " for dataset "
@@ -376,7 +376,7 @@
 
 Status DataServiceDispatcherImpl::CreateJob(
     int64 dataset_id, ProcessingMode processing_mode,
-    absl::optional<NamedJobKey> named_job_key, std::shared_ptr<const Job>* job)
+    absl::optional<NamedJobKey> named_job_key, std::shared_ptr<const Job>& job)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   switch (processing_mode) {
     case ProcessingMode::PARALLEL_EPOCHS:
@@ -421,22 +421,22 @@
 
 Status DataServiceDispatcherImpl::CreateTasksForJob(
     std::shared_ptr<const Job> job,
-    std::vector<std::shared_ptr<const Task>>* tasks)
+    std::vector<std::shared_ptr<const Task>>& tasks)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
-  tasks->clear();
-  tasks->reserve(workers.size());
+  tasks.clear();
+  tasks.reserve(workers.size());
   for (const auto& worker : workers) {
     std::shared_ptr<const Task> task;
-    TF_RETURN_IF_ERROR(CreateTask(job, worker->address, &task));
-    tasks->push_back(task);
+    TF_RETURN_IF_ERROR(CreateTask(job, worker->address, task));
+    tasks.push_back(task);
   }
   return Status::OK();
 }
 
 Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
                                              const std::string& worker_address,
-                                             std::shared_ptr<const Task>* task)
+                                             std::shared_ptr<const Task>& task)
     EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   int64 task_id = state_.NextAvailableTaskId();
   Update update;
@@ -459,19 +459,19 @@
 }
 
 Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
-    const std::string& worker_address, WorkerService::Stub** out_stub)
+    const std::string& worker_address, WorkerService::Stub*& out_stub)
     LOCKS_EXCLUDED(mu_) {
   {
     mutex_lock l(mu_);
     auto it = worker_stubs_.find(worker_address);
     if (it != worker_stubs_.end()) {
-      *out_stub = it->second.get();
+      out_stub = it->second.get();
       return Status::OK();
     }
   }
   std::unique_ptr<WorkerService::Stub> stub;
   TF_RETURN_IF_ERROR(
-      CreateWorkerStub(worker_address, config_.protocol(), &stub));
+      CreateWorkerStub(worker_address, config_.protocol(), stub));
   {
     mutex_lock l(mu_);
     // A concurrent call could have already created the stub.
@@ -479,7 +479,7 @@
     if (worker == nullptr) {
       worker = std::move(stub);
     }
-    *out_stub = worker.get();
+    out_stub = worker.get();
   }
   return Status::OK();
 }
@@ -495,7 +495,7 @@
   {
     mutex_lock l(mu_);
     std::shared_ptr<const Dataset> dataset;
-    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, &dataset));
+    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->dataset_id, dataset));
     std::string dataset_key =
         DatasetKey(dataset->dataset_id, dataset->fingerprint);
     if (config_.work_dir().empty()) {
@@ -511,7 +511,7 @@
   task_def->set_task_id(task->task_id);
   ProcessTaskResponse resp;
   WorkerService::Stub* stub;
-  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, &stub));
+  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
   grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
   if (!s.ok()) {
     return grpc_util::WrapError(
@@ -530,7 +530,7 @@
   std::shared_ptr<const Job> job;
   TF_RETURN_IF_ERROR(state_.JobForJobClientId(request->job_client_id(), job));
   std::vector<std::shared_ptr<const Task>> tasks;
-  TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, &tasks));
+  TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
   for (const auto& task : tasks) {
     TaskInfo* task_info = response->mutable_task_info()->Add();
     task_info->set_worker_address(task->worker_address);
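
Note: GetOrCreateWorkerStub above uses a check / create-unlocked / insert
idiom: the cache is consulted under mu_, the (potentially slow) channel and
stub creation runs with the lock released, and the insert tolerates a
concurrent creation by keeping whichever stub landed first. A distilled
sketch with placeholder names (Stub, MakeStub):

#include <map>
#include <memory>
#include <mutex>
#include <string>

struct Stub {};  // Placeholder for WorkerService::Stub.

// Placeholder for CreateWorkerStub; real channel setup may block on I/O.
std::unique_ptr<Stub> MakeStub(const std::string& address) {
  return std::make_unique<Stub>();
}

class StubCache {
 public:
  Stub* GetOrCreate(const std::string& address) {
    {
      std::lock_guard<std::mutex> l(mu_);
      auto it = stubs_.find(address);
      if (it != stubs_.end()) return it->second.get();
    }
    // Create without holding the lock so other lookups are not blocked.
    std::unique_ptr<Stub> stub = MakeStub(address);
    std::lock_guard<std::mutex> l(mu_);
    // A concurrent call may have created the stub already; keep the first.
    std::unique_ptr<Stub>& slot = stubs_[address];
    if (slot == nullptr) slot = std::move(stub);
    return slot.get();
  }

 private:
  std::mutex mu_;
  std::map<std::string, std::unique_ptr<Stub>> stubs_;
};
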
diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h
index 34cdc67..2cf341a 100644
--- a/tensorflow/core/data/service/dispatcher_impl.h
+++ b/tensorflow/core/data/service/dispatcher_impl.h
@@ -77,19 +77,20 @@
 
  private:
   // Registers a dataset with the given fingerprint, storing the new dataset's
-  // id in `*dataset-id`.
+  // id in `dataset_id`.
   Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
-                         int64* dataset_id) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+                         int64& dataset_id) EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
-  // stub and stores it in `worker_stubs_`.
+  // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
+  // stored in `out_stub`.
   Status GetOrCreateWorkerStub(const std::string& worker_address,
-                               WorkerService::Stub** out_stub)
+                               WorkerService::Stub*& out_stub)
       LOCKS_EXCLUDED(mu_);
-  // Creates a job and stores it in `*job`. This method updates the
+  // Creates a job and stores it in `job`. This method updates the
   // dispatcher state with the new job, but does not assign tasks to workers.
   Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
                    absl::optional<DispatcherState::NamedJobKey> named_job_key,
-                   std::shared_ptr<const DispatcherState::Job>* job)
+                   std::shared_ptr<const DispatcherState::Job>& job)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Acquires a job client id to read from the given job and sets
   // `job_client_id`.
@@ -97,17 +98,17 @@
       const std::shared_ptr<const DispatcherState::Job>& job,
       int64& job_client_id) EXCLUSIVE_LOCKS_REQUIRED(mu_);
   // Creates one task for each worker, for the given job. The created tasks are
-  // stored in `*tasks`. This method only updates dispatcher metadata with the
+  // stored in `tasks`. This method only updates dispatcher metadata with the
   // new tasks, but doesn't assign the tasks to the workers.
   Status CreateTasksForJob(
       std::shared_ptr<const DispatcherState::Job> job,
-      std::vector<std::shared_ptr<const DispatcherState::Task>>* tasks)
+      std::vector<std::shared_ptr<const DispatcherState::Task>>& tasks)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  // Creates a new task for a job, storing the created task in `*task`.
+  // Creates a new task for a job, storing the created task in `task`.
   Status CreateTask(std::shared_ptr<const DispatcherState::Job> job,
                     const std::string& worker_address,
-                    std::shared_ptr<const DispatcherState::Task>* task);
+                    std::shared_ptr<const DispatcherState::Task>& task);
   // Assigns the list of tasks to the workers indicated by their
   // `worker_address` fields.
   Status AssignTasks(
diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc
index b302810..3afee88 100644
--- a/tensorflow/core/data/service/dispatcher_state.cc
+++ b/tensorflow/core/data/service/dispatcher_state.cc
@@ -25,7 +25,7 @@
 
 DispatcherState::DispatcherState() {}
 
-Status DispatcherState::Apply(Update update) {
+Status DispatcherState::Apply(const Update& update) {
   switch (update.update_type_case()) {
     case Update::kRegisterDataset:
       RegisterDataset(update.register_dataset());
@@ -151,32 +151,32 @@
 }
 
 Status DispatcherState::DatasetFromId(
-    int64 id, std::shared_ptr<const Dataset>* dataset) const {
+    int64 id, std::shared_ptr<const Dataset>& dataset) const {
   auto it = datasets_by_id_.find(id);
   if (it == datasets_by_id_.end()) {
     return errors::NotFound("Dataset id ", id, " not found");
   }
-  *dataset = it->second;
+  dataset = it->second;
   return Status::OK();
 }
 
 Status DispatcherState::DatasetFromFingerprint(
-    uint64 fingerprint, std::shared_ptr<const Dataset>* dataset) const {
+    uint64 fingerprint, std::shared_ptr<const Dataset>& dataset) const {
   auto it = datasets_by_fingerprint_.find(fingerprint);
   if (it == datasets_by_fingerprint_.end()) {
     return errors::NotFound("Dataset fingerprint ", fingerprint, " not found");
   }
-  *dataset = it->second;
+  dataset = it->second;
   return Status::OK();
 }
 
 Status DispatcherState::WorkerFromAddress(
-    const std::string& address, std::shared_ptr<const Worker>* worker) const {
+    const std::string& address, std::shared_ptr<const Worker>& worker) const {
   auto it = workers_.find(address);
   if (it == workers_.end()) {
     return errors::NotFound("Worker with address ", address, " not found.");
   }
-  *worker = it->second;
+  worker = it->second;
   return Status::OK();
 }
 
@@ -201,23 +201,23 @@
 }
 
 Status DispatcherState::JobFromId(int64 id,
-                                  std::shared_ptr<const Job>* job) const {
+                                  std::shared_ptr<const Job>& job) const {
   auto it = jobs_.find(id);
   if (it == jobs_.end()) {
     return errors::NotFound("Job id ", id, " not found");
   }
-  *job = it->second;
+  job = it->second;
   return Status::OK();
 }
 
 Status DispatcherState::NamedJobByKey(NamedJobKey named_job_key,
-                                      std::shared_ptr<const Job>* job) const {
+                                      std::shared_ptr<const Job>& job) const {
   auto it = named_jobs_.find(named_job_key);
   if (it == named_jobs_.end()) {
     return errors::NotFound("Named job key (", named_job_key.name, ", ",
                             named_job_key.index, ") not found");
   }
-  *job = it->second;
+  job = it->second;
   return Status::OK();
 }
 
@@ -239,25 +239,25 @@
 }
 
 Status DispatcherState::TaskFromId(int64 id,
-                                   std::shared_ptr<const Task>* task) const {
+                                   std::shared_ptr<const Task>& task) const {
   auto it = tasks_.find(id);
   if (it == tasks_.end()) {
     return errors::NotFound("Task ", id, " not found");
   }
-  *task = it->second;
+  task = it->second;
   return Status::OK();
 }
 
 Status DispatcherState::TasksForJob(
-    int64 job_id, std::vector<std::shared_ptr<const Task>>* tasks) const {
+    int64 job_id, std::vector<std::shared_ptr<const Task>>& tasks) const {
   auto it = tasks_by_job_.find(job_id);
   if (it == tasks_by_job_.end()) {
     return errors::NotFound("Job ", job_id, " not found");
   }
-  tasks->clear();
-  tasks->reserve(it->second.size());
+  tasks.clear();
+  tasks.reserve(it->second.size());
   for (const auto& task : it->second) {
-    tasks->push_back(task);
+    tasks.push_back(task);
   }
   return Status::OK();
 }
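
Note: taking Update by const reference lets journal replay apply the same
proto instance repeatedly without copying. A hedged sketch of the replay
loop, condensed from the dispatcher's startup path shown earlier (error
handling beyond status propagation is elided):

// Assumes `reader` follows the FileJournalReader::Read(update, end_of_journal)
// interface used in dispatcher_impl.cc.
Status Replay(FileJournalReader& reader, DispatcherState& state) {
  Update update;
  bool end_of_journal = false;
  TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
  while (!end_of_journal) {
    TF_RETURN_IF_ERROR(state.Apply(update));
    TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
  }
  return Status::OK();
}
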
diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h
index d2080c8..59d7f19 100644
--- a/tensorflow/core/data/service/dispatcher_state.h
+++ b/tensorflow/core/data/service/dispatcher_state.h
@@ -56,7 +56,7 @@
   DispatcherState& operator=(const DispatcherState&) = delete;
 
   // Applies the given update to the dispatcher's state.
-  Status Apply(Update update);
+  Status Apply(const Update& update);
 
   // A dataset registered with the dispatcher.
   struct Dataset {
@@ -129,15 +129,15 @@
   // Returns the next available dataset id.
   int64 NextAvailableDatasetId() const;
   // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset.
-  Status DatasetFromId(int64 id, std::shared_ptr<const Dataset>* dataset) const;
+  Status DatasetFromId(int64 id, std::shared_ptr<const Dataset>& dataset) const;
   // Gets a dataset by fingerprint. Returns NOT_FOUND if there is no such
   // dataset.
   Status DatasetFromFingerprint(uint64 fingerprint,
-                                std::shared_ptr<const Dataset>* dataset) const;
+                                std::shared_ptr<const Dataset>& dataset) const;
 
   // Gets a worker by address. Returns NOT_FOUND if there is no such worker.
   Status WorkerFromAddress(const std::string& address,
-                           std::shared_ptr<const Worker>* worker) const;
+                           std::shared_ptr<const Worker>& worker) const;
   // Lists all workers registered with the dispatcher.
   std::vector<std::shared_ptr<const Worker>> ListWorkers() const;
 
@@ -146,9 +146,9 @@
   // Returns a list of all jobs.
   std::vector<std::shared_ptr<const Job>> ListJobs();
   // Gets a job by id. Returns NOT_FOUND if there is no such job.
-  Status JobFromId(int64 id, std::shared_ptr<const Job>* job) const;
+  Status JobFromId(int64 id, std::shared_ptr<const Job>& job) const;
   // Gets a named job by key. Returns NOT_FOUND if there is no such job.
-  Status NamedJobByKey(NamedJobKey key, std::shared_ptr<const Job>* job) const;
+  Status NamedJobByKey(NamedJobKey key, std::shared_ptr<const Job>& job) const;
 
   // Returns the job associated with the given job client id. Returns NOT_FOUND
   // if the job_client_id is unknown or has been released.
@@ -160,12 +160,12 @@
   // Returns the next available task id.
   int64 NextAvailableTaskId() const;
   // Gets a task by id. Returns NOT_FOUND if there is no such task.
-  Status TaskFromId(int64 id, std::shared_ptr<const Task>* task) const;
-  // Stores a list of all tasks for the given job to `*tasks`. Returns NOT_FOUND
+  Status TaskFromId(int64 id, std::shared_ptr<const Task>& task) const;
+  // Stores a list of all tasks for the given job to `tasks`. Returns NOT_FOUND
   // if there is no such job.
   Status TasksForJob(int64 job_id,
-                     std::vector<std::shared_ptr<const Task>>* tasks) const;
-  // Stores a list of all tasks for the given worker to `*tasks`. Returns
+                     std::vector<std::shared_ptr<const Task>>& tasks) const;
+  // Stores a list of all tasks for the given worker to `tasks`. Returns
   // NOT_FOUND if there is no such worker.
   Status TasksForWorker(const absl::string_view worker_address,
                         std::vector<std::shared_ptr<const Task>>& tasks) const;
diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc
index 1676fc7..43a47f8 100644
--- a/tensorflow/core/data/service/dispatcher_state_test.cc
+++ b/tensorflow/core/data/service/dispatcher_state_test.cc
@@ -36,39 +36,39 @@
 using ::testing::IsEmpty;
 using ::testing::SizeIs;
 
-Status RegisterDataset(int64 id, uint64 fingerprint, DispatcherState* state) {
+Status RegisterDataset(int64 id, uint64 fingerprint, DispatcherState& state) {
   Update update;
   RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
   register_dataset->set_dataset_id(id);
   register_dataset->set_fingerprint(fingerprint);
-  TF_RETURN_IF_ERROR(state->Apply(update));
+  TF_RETURN_IF_ERROR(state.Apply(update));
   return Status::OK();
 }
 
-Status RegisterDataset(int64 id, DispatcherState* state) {
+Status RegisterDataset(int64 id, DispatcherState& state) {
   return RegisterDataset(id, /*fingerprint=*/1, state);
 }
 
-Status RegisterWorker(std::string worker_address, DispatcherState* state) {
+Status RegisterWorker(std::string worker_address, DispatcherState& state) {
   Update update;
   update.mutable_register_worker()->set_worker_address(worker_address);
-  TF_RETURN_IF_ERROR(state->Apply(update));
+  TF_RETURN_IF_ERROR(state.Apply(update));
   return Status::OK();
 }
 
 Status CreateAnonymousJob(int64 job_id, int64 dataset_id,
-                          DispatcherState* state) {
+                          DispatcherState& state) {
   Update update;
   CreateJobUpdate* create_job = update.mutable_create_job();
   create_job->set_job_id(job_id);
   create_job->set_dataset_id(dataset_id);
   create_job->set_processing_mode(ProcessingModeDef::PARALLEL_EPOCHS);
-  TF_RETURN_IF_ERROR(state->Apply(update));
+  TF_RETURN_IF_ERROR(state.Apply(update));
   return Status::OK();
 }
 
 Status CreateNamedJob(int64 job_id, int64 dataset_id, NamedJobKey named_job_key,
-                      DispatcherState* state) {
+                      DispatcherState& state) {
   Update update;
   CreateJobUpdate* create_job = update.mutable_create_job();
   create_job->set_job_id(job_id);
@@ -77,49 +77,49 @@
   NamedJobKeyDef* key = create_job->mutable_named_job_key();
   key->set_name(named_job_key.name);
   key->set_index(named_job_key.index);
-  TF_RETURN_IF_ERROR(state->Apply(update));
+  TF_RETURN_IF_ERROR(state.Apply(update));
   return Status::OK();
 }
 
 Status AcquireJobClientId(int64 job_id, int64 job_client_id,
-                          DispatcherState* state) {
+                          DispatcherState& state) {
   Update update;
   AcquireJobClientUpdate* acquire_job_client =
       update.mutable_acquire_job_client();
   acquire_job_client->set_job_id(job_id);
   acquire_job_client->set_job_client_id(job_client_id);
-  TF_RETURN_IF_ERROR(state->Apply(update));
+  TF_RETURN_IF_ERROR(state.Apply(update));
   return Status::OK();
 }
 
 Status ReleaseJobClientId(int64 job_client_id, int64 release_time,
-                          DispatcherState* state) {
+                          DispatcherState& state) {
   Update update;
   ReleaseJobClientUpdate* release_job_client =
       update.mutable_release_job_client();
   release_job_client->set_job_client_id(job_client_id);
   release_job_client->set_time_micros(release_time);
-  TF_RETURN_IF_ERROR(state->Apply(update));
+  TF_RETURN_IF_ERROR(state.Apply(update));
   return Status::OK();
 }
 
 Status CreateTask(int64 task_id, int64 job_id, int64 dataset_id,
-                  const std::string& worker_address, DispatcherState* state) {
+                  const std::string& worker_address, DispatcherState& state) {
   Update update;
   CreateTaskUpdate* create_task = update.mutable_create_task();
   create_task->set_task_id(task_id);
   create_task->set_job_id(job_id);
   create_task->set_dataset_id(dataset_id);
   create_task->set_worker_address(worker_address);
-  TF_RETURN_IF_ERROR(state->Apply(update));
+  TF_RETURN_IF_ERROR(state.Apply(update));
   return Status::OK();
 }
 
-Status FinishTask(int64 task_id, DispatcherState* state) {
+Status FinishTask(int64 task_id, DispatcherState& state) {
   Update update;
   FinishTaskUpdate* finish_task = update.mutable_finish_task();
   finish_task->set_task_id(task_id);
-  TF_RETURN_IF_ERROR(state->Apply(update));
+  TF_RETURN_IF_ERROR(state.Apply(update));
   return Status::OK();
 }
 }  // namespace
@@ -128,17 +128,17 @@
   int64 id = 10;
   uint64 fingerprint = 20;
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(id, fingerprint, &state));
+  TF_EXPECT_OK(RegisterDataset(id, fingerprint, state));
   EXPECT_EQ(state.NextAvailableDatasetId(), id + 1);
 
   {
     std::shared_ptr<const Dataset> dataset;
-    TF_EXPECT_OK(state.DatasetFromFingerprint(fingerprint, &dataset));
+    TF_EXPECT_OK(state.DatasetFromFingerprint(fingerprint, dataset));
     EXPECT_EQ(dataset->dataset_id, id);
   }
   {
     std::shared_ptr<const Dataset> dataset;
-    TF_EXPECT_OK(state.DatasetFromId(id, &dataset));
+    TF_EXPECT_OK(state.DatasetFromId(id, dataset));
     EXPECT_EQ(dataset->fingerprint, fingerprint);
   }
 }
@@ -146,14 +146,14 @@
 TEST(DispatcherState, MissingDatasetId) {
   DispatcherState state;
   std::shared_ptr<const Dataset> dataset;
-  Status s = state.DatasetFromId(0, &dataset);
+  Status s = state.DatasetFromId(0, dataset);
   EXPECT_EQ(s.code(), error::NOT_FOUND);
 }
 
 TEST(DispatcherState, MissingDatasetFingerprint) {
   DispatcherState state;
   std::shared_ptr<const Dataset> dataset;
-  Status s = state.DatasetFromFingerprint(0, &dataset);
+  Status s = state.DatasetFromFingerprint(0, dataset);
   EXPECT_EQ(s.code(), error::NOT_FOUND);
 }
 
@@ -161,7 +161,7 @@
   DispatcherState state;
   int64 id = state.NextAvailableDatasetId();
   uint64 fingerprint = 20;
-  TF_EXPECT_OK(RegisterDataset(id, fingerprint, &state));
+  TF_EXPECT_OK(RegisterDataset(id, fingerprint, state));
   EXPECT_NE(state.NextAvailableDatasetId(), id);
   EXPECT_EQ(state.NextAvailableDatasetId(), state.NextAvailableDatasetId());
 }
@@ -169,9 +169,9 @@
 TEST(DispatcherState, RegisterWorker) {
   DispatcherState state;
   std::string address = "test_worker_address";
-  TF_EXPECT_OK(RegisterWorker(address, &state));
+  TF_EXPECT_OK(RegisterWorker(address, state));
   std::shared_ptr<const Worker> worker;
-  TF_EXPECT_OK(state.WorkerFromAddress(address, &worker));
+  TF_EXPECT_OK(state.WorkerFromAddress(address, worker));
   EXPECT_EQ(worker->address, address);
 }
 
@@ -183,12 +183,12 @@
     std::vector<std::shared_ptr<const Worker>> workers = state.ListWorkers();
     EXPECT_THAT(workers, IsEmpty());
   }
-  TF_EXPECT_OK(RegisterWorker(address_1, &state));
+  TF_EXPECT_OK(RegisterWorker(address_1, state));
   {
     std::vector<std::shared_ptr<const Worker>> workers = state.ListWorkers();
     EXPECT_THAT(workers, SizeIs(1));
   }
-  TF_EXPECT_OK(RegisterWorker(address_2, &state));
+  TF_EXPECT_OK(RegisterWorker(address_2, state));
   {
     std::vector<std::shared_ptr<const Worker>> workers = state.ListWorkers();
     EXPECT_THAT(workers, SizeIs(2));
@@ -198,7 +198,7 @@
 TEST(DispatcherState, MissingWorker) {
   DispatcherState state;
   std::shared_ptr<const Worker> worker;
-  Status s = state.WorkerFromAddress("test_worker_address", &worker);
+  Status s = state.WorkerFromAddress("test_worker_address", worker);
   EXPECT_EQ(s.code(), error::NOT_FOUND);
 }
 
@@ -213,15 +213,15 @@
   int64 job_id = 3;
   int64 dataset_id = 10;
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   std::shared_ptr<const Job> job;
-  TF_EXPECT_OK(state.JobFromId(job_id, &job));
+  TF_EXPECT_OK(state.JobFromId(job_id, job));
   EXPECT_EQ(state.NextAvailableJobId(), job_id + 1);
   EXPECT_EQ(job->dataset_id, dataset_id);
   EXPECT_EQ(job->job_id, job_id);
   std::vector<std::shared_ptr<const Task>> tasks;
-  TF_EXPECT_OK(state.TasksForJob(job_id, &tasks));
+  TF_EXPECT_OK(state.TasksForJob(job_id, tasks));
   EXPECT_THAT(tasks, IsEmpty());
   EXPECT_FALSE(job->finished);
 }
@@ -230,11 +230,11 @@
   int64 job_id = 3;
   int64 dataset_id = 10;
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
   NamedJobKey named_job_key("test", 1);
-  TF_EXPECT_OK(CreateNamedJob(job_id, dataset_id, named_job_key, &state));
+  TF_EXPECT_OK(CreateNamedJob(job_id, dataset_id, named_job_key, state));
   std::shared_ptr<const Job> job;
-  TF_EXPECT_OK(state.NamedJobByKey(named_job_key, &job));
+  TF_EXPECT_OK(state.NamedJobByKey(named_job_key, job));
   EXPECT_EQ(state.NextAvailableJobId(), job_id + 1);
   EXPECT_EQ(job->dataset_id, dataset_id);
   EXPECT_EQ(job->job_id, job_id);
@@ -247,13 +247,13 @@
   int64 task_id = 8;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
-  TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
+  TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, state));
   EXPECT_EQ(state.NextAvailableTaskId(), task_id + 1);
   {
     std::shared_ptr<const Task> task;
-    TF_EXPECT_OK(state.TaskFromId(task_id, &task));
+    TF_EXPECT_OK(state.TaskFromId(task_id, task));
     EXPECT_EQ(task->task_id, task_id);
     EXPECT_EQ(task->job_id, job_id);
     EXPECT_EQ(task->dataset_id, dataset_id);
@@ -261,7 +261,7 @@
   }
   {
     std::vector<std::shared_ptr<const Task>> tasks;
-    TF_EXPECT_OK(state.TasksForJob(job_id, &tasks));
+    TF_EXPECT_OK(state.TasksForJob(job_id, tasks));
     EXPECT_THAT(tasks, SizeIs(1));
   }
   {
@@ -278,15 +278,15 @@
   int64 task_id_2 = 9;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_1, job_id, dataset_id, worker_address, &state));
+      CreateTask(task_id_1, job_id, dataset_id, worker_address, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_2, job_id, dataset_id, worker_address, &state));
+      CreateTask(task_id_2, job_id, dataset_id, worker_address, state));
   {
     std::vector<std::shared_ptr<const Task>> tasks;
-    TF_EXPECT_OK(state.TasksForJob(job_id, &tasks));
+    TF_EXPECT_OK(state.TasksForJob(job_id, tasks));
     EXPECT_THAT(tasks, SizeIs(2));
   }
 }
@@ -299,21 +299,21 @@
   int64 task_id_2 = 9;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id_1, dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id_2, dataset_id, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id_1, dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id_2, dataset_id, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_1, job_id_1, dataset_id, worker_address, &state));
+      CreateTask(task_id_1, job_id_1, dataset_id, worker_address, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_2, job_id_2, dataset_id, worker_address, &state));
+      CreateTask(task_id_2, job_id_2, dataset_id, worker_address, state));
   {
     std::vector<std::shared_ptr<const Task>> tasks;
-    TF_EXPECT_OK(state.TasksForJob(job_id_1, &tasks));
+    TF_EXPECT_OK(state.TasksForJob(job_id_1, tasks));
     EXPECT_THAT(tasks, SizeIs(1));
   }
   {
     std::vector<std::shared_ptr<const Task>> tasks;
-    TF_EXPECT_OK(state.TasksForJob(job_id_2, &tasks));
+    TF_EXPECT_OK(state.TasksForJob(job_id_2, tasks));
     EXPECT_THAT(tasks, SizeIs(1));
   }
 }
@@ -325,12 +325,12 @@
   int64 task_id_2 = 9;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_1, job_id, dataset_id, worker_address, &state));
+      CreateTask(task_id_1, job_id, dataset_id, worker_address, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_2, job_id, dataset_id, worker_address, &state));
+      CreateTask(task_id_2, job_id, dataset_id, worker_address, state));
   {
     std::vector<std::shared_ptr<const Task>> tasks;
     TF_EXPECT_OK(state.TasksForWorker(worker_address, tasks));
@@ -346,12 +346,12 @@
   std::string worker_address_1 = "test_worker_address_1";
   std::string worker_address_2 = "test_worker_address_2";
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_1, job_id, dataset_id, worker_address_1, &state));
+      CreateTask(task_id_1, job_id, dataset_id, worker_address_1, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_2, job_id, dataset_id, worker_address_2, &state));
+      CreateTask(task_id_2, job_id, dataset_id, worker_address_2, state));
   {
     std::vector<std::shared_ptr<const Task>> tasks;
     TF_EXPECT_OK(state.TasksForWorker(worker_address_1, tasks));
@@ -367,7 +367,7 @@
 TEST(DispatcherState, GetTasksForWorkerEmpty) {
   std::string worker_address = "test_worker_address";
   DispatcherState state;
-  TF_EXPECT_OK(RegisterWorker(worker_address, &state));
+  TF_EXPECT_OK(RegisterWorker(worker_address, state));
   {
     std::vector<std::shared_ptr<const Task>> tasks;
     TF_EXPECT_OK(state.TasksForWorker(worker_address, tasks));
@@ -381,15 +381,15 @@
   int64 task_id = 4;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
-  TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, &state));
-  TF_EXPECT_OK(FinishTask(task_id, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
+  TF_EXPECT_OK(CreateTask(task_id, job_id, dataset_id, worker_address, state));
+  TF_EXPECT_OK(FinishTask(task_id, state));
   std::shared_ptr<const Task> task;
-  TF_EXPECT_OK(state.TaskFromId(task_id, &task));
+  TF_EXPECT_OK(state.TaskFromId(task_id, task));
   EXPECT_TRUE(task->finished);
   std::shared_ptr<const Job> job;
-  TF_EXPECT_OK(state.JobFromId(job_id, &job));
+  TF_EXPECT_OK(state.JobFromId(job_id, job));
   EXPECT_TRUE(job->finished);
 }
 
@@ -400,24 +400,24 @@
   int64 task_id_2 = 5;
   std::string worker_address = "test_worker_address";
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_1, job_id, dataset_id, worker_address, &state));
+      CreateTask(task_id_1, job_id, dataset_id, worker_address, state));
   TF_EXPECT_OK(
-      CreateTask(task_id_2, job_id, dataset_id, worker_address, &state));
+      CreateTask(task_id_2, job_id, dataset_id, worker_address, state));
 
-  TF_EXPECT_OK(FinishTask(task_id_1, &state));
+  TF_EXPECT_OK(FinishTask(task_id_1, state));
   {
     std::shared_ptr<const Job> job;
-    TF_EXPECT_OK(state.JobFromId(job_id, &job));
+    TF_EXPECT_OK(state.JobFromId(job_id, job));
     EXPECT_FALSE(job->finished);
   }
 
-  TF_EXPECT_OK(FinishTask(task_id_2, &state));
+  TF_EXPECT_OK(FinishTask(task_id_2, state));
   {
     std::shared_ptr<const Job> job;
-    TF_EXPECT_OK(state.JobFromId(job_id, &job));
+    TF_EXPECT_OK(state.JobFromId(job_id, job));
     EXPECT_TRUE(job->finished);
   }
 }
@@ -428,14 +428,14 @@
   int64 job_client_id_2 = 2;
   int64 dataset_id = 10;
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
-  TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id_1, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
+  TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id_1, state));
   {
     std::shared_ptr<const Job> job;
-    TF_EXPECT_OK(state.JobFromId(job_id, &job));
+    TF_EXPECT_OK(state.JobFromId(job_id, job));
     EXPECT_EQ(job->num_clients, 1);
-    TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id_2, &state));
+    TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id_2, state));
     EXPECT_EQ(job->num_clients, 2);
   }
   {
@@ -456,12 +456,12 @@
   int64 job_client_id = 6;
   int64 release_time = 100;
   DispatcherState state;
-  TF_EXPECT_OK(RegisterDataset(dataset_id, &state));
-  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, &state));
-  TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id, &state));
-  TF_EXPECT_OK(ReleaseJobClientId(job_client_id, release_time, &state));
+  TF_EXPECT_OK(RegisterDataset(dataset_id, state));
+  TF_EXPECT_OK(CreateAnonymousJob(job_id, dataset_id, state));
+  TF_EXPECT_OK(AcquireJobClientId(job_id, job_client_id, state));
+  TF_EXPECT_OK(ReleaseJobClientId(job_client_id, release_time, state));
   std::shared_ptr<const Job> job;
-  TF_EXPECT_OK(state.JobFromId(job_id, &job));
+  TF_EXPECT_OK(state.JobFromId(job_id, job));
   EXPECT_EQ(job->num_clients, 0);
   Status s = state.JobForJobClientId(job_client_id, job);
   EXPECT_EQ(s.code(), error::NOT_FOUND);
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
index a7a3079..fbfc5d2 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc
@@ -26,9 +26,9 @@
 using ::grpc::ServerContext;
 
 GrpcDispatcherImpl::GrpcDispatcherImpl(
-    ServerBuilder* server_builder, const experimental::DispatcherConfig& config)
+    const experimental::DispatcherConfig& config, ServerBuilder& server_builder)
     : impl_(config) {
-  server_builder->RegisterService(this);
+  server_builder.RegisterService(this);
   VLOG(1) << "Registered data service dispatcher";
 }
 
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h
index 81f1cbf..171deed 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.h
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h
@@ -25,18 +25,12 @@
 namespace data {
 
 // This class is a wrapper that handles communication for gRPC.
-//
-// Example usage:
-//
-// ::grpc::ServerBuilder builder;
-// // configure builder
-// GrpcDispatcherImpl data_service(&builder);
-// builder.BuildAndStart()
-//
 class GrpcDispatcherImpl : public DispatcherService::Service {
  public:
-  explicit GrpcDispatcherImpl(::grpc::ServerBuilder* server_builder,
-                              const experimental::DispatcherConfig& config);
+  // Constructs a GrpcDispatcherImpl with the given config, and registers it
+  // with `server_builder`.
+  explicit GrpcDispatcherImpl(const experimental::DispatcherConfig& config,
+                              ::grpc::ServerBuilder& server_builder);
   ~GrpcDispatcherImpl() override {}
 
   Status Start();
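The old example-usage comment is deleted rather than updated; under the new signature, the construction sequence would look like this sketch (builder and config setup elided):

    ::grpc::ServerBuilder builder;
    // ... add listening port, credentials, etc. ...
    experimental::DispatcherConfig config;
    GrpcDispatcherImpl data_service(config, builder);  // registers with builder
    builder.BuildAndStart();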
diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc
index b3a37fe..ef386be 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.cc
+++ b/tensorflow/core/data/service/grpc_worker_impl.cc
@@ -24,10 +24,10 @@
 using ::grpc::ServerBuilder;
 using ::grpc::ServerContext;
 
-GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
-                               const experimental::WorkerConfig& config)
+GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config,
+                               ServerBuilder& server_builder)
     : impl_(config) {
-  server_builder->RegisterService(this);
+  server_builder.RegisterService(this);
   VLOG(1) << "Registered data service worker";
 }
 
diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h
index c42e563..3d30af9 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.h
+++ b/tensorflow/core/data/service/grpc_worker_impl.h
@@ -25,18 +25,12 @@
 namespace data {
 
 // This class is a wrapper that handles communication for gRPC.
-//
-// Example usage:
-//
-// ::grpc::ServerBuilder builder;
-// // configure builder
-// GrpcWorkerImpl data_service(&builder);
-// builder.BuildAndStart()
-//
 class GrpcWorkerImpl : public WorkerService::Service {
  public:
-  explicit GrpcWorkerImpl(::grpc::ServerBuilder* server_builder,
-                          const experimental::WorkerConfig& config);
+  // Constructs a GrpcWorkerImpl with the given config, and registers it with
+  // `server_builder`.
+  explicit GrpcWorkerImpl(const experimental::WorkerConfig& config,
+                          ::grpc::ServerBuilder& server_builder);
   ~GrpcWorkerImpl() override {}
 
   Status Start(const std::string& worker_address);
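The worker wrapper mirrors the dispatcher change; the same sketch with the worker types:

    ::grpc::ServerBuilder builder;
    experimental::WorkerConfig config;
    GrpcWorkerImpl worker_service(config, builder);  // registers with builder
    builder.BuildAndStart();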
diff --git a/tensorflow/core/data/service/journal.cc b/tensorflow/core/data/service/journal.cc
index b0ce087..979fc78 100644
--- a/tensorflow/core/data/service/journal.cc
+++ b/tensorflow/core/data/service/journal.cc
@@ -96,7 +96,7 @@
   return UpdateFile(DataServiceJournalFile(journal_dir_, 0));
 }
 
-Status FileJournalReader::Read(Update* update, bool* end_of_journal) {
+Status FileJournalReader::Read(Update& update, bool& end_of_journal) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
   while (true) {
     tstring record;
@@ -108,20 +108,20 @@
       if (errors::IsNotFound(env_->FileExists(next_journal_file))) {
         VLOG(3) << "Next journal file " << next_journal_file
                 << " does not exist. End of journal reached.";
-        *end_of_journal = true;
+        end_of_journal = true;
         return Status::OK();
       }
       TF_RETURN_IF_ERROR(UpdateFile(next_journal_file));
       continue;
     }
     TF_RETURN_IF_ERROR(s);
-    if (!update->ParseFromString(record)) {
+    if (!update.ParseFromString(record)) {
       return errors::DataLoss("Failed to parse journal record.");
     }
     if (VLOG_IS_ON(4)) {
-      VLOG(4) << "Read journal entry: " << update->DebugString();
+      VLOG(4) << "Read journal entry: " << update.DebugString();
     }
-    *end_of_journal = false;
+    end_of_journal = false;
     return Status::OK();
   }
 }
diff --git a/tensorflow/core/data/service/journal.h b/tensorflow/core/data/service/journal.h
index 3483497..e31830e 100644
--- a/tensorflow/core/data/service/journal.h
+++ b/tensorflow/core/data/service/journal.h
@@ -77,9 +77,9 @@
 class JournalReader {
  public:
   virtual ~JournalReader() = default;
-  // Reads the next update from the journal. Sets `*end_of_journal=true` if
+  // Reads the next update from the journal. Sets `end_of_journal=true` if
   // there are no more updates left in the journal.
-  virtual Status Read(Update* update, bool* end_of_journal) = 0;
+  virtual Status Read(Update& update, bool& end_of_journal) = 0;
 };
 
 // JournalReader is not thread-safe, requiring external synchronization when
@@ -93,7 +93,7 @@
   FileJournalReader(const FileJournalReader&) = delete;
   FileJournalReader& operator=(const FileJournalReader&) = delete;
 
-  Status Read(Update* update, bool* end_of_journal) override;
+  Status Read(Update& update, bool& end_of_journal) override;
 
  private:
   // Initializes the reader if it is not yet initialized.
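With `Read` now taking references, a typical drain loop over a journal looks like the following sketch (wrapper and directory name are illustrative):

    Status DrainJournal(Env* env, const std::string& journal_dir) {
      FileJournalReader reader(env, journal_dir);
      Update update;
      bool end_of_journal = false;
      while (true) {
        TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
        if (end_of_journal) return Status::OK();
        // ... apply `update` ...
      }
    }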
diff --git a/tensorflow/core/data/service/journal_test.cc b/tensorflow/core/data/service/journal_test.cc
index 313b216..3f55447 100644
--- a/tensorflow/core/data/service/journal_test.cc
+++ b/tensorflow/core/data/service/journal_test.cc
@@ -28,12 +28,12 @@
 namespace {
 using ::testing::HasSubstr;
 
-bool NewJournalDir(std::string* journal_dir) {
+bool NewJournalDir(std::string& journal_dir) {
   std::string filename = testing::TmpDir();
   if (!Env::Default()->CreateUniqueFileName(&filename, "journal_dir")) {
     return false;
   }
-  *journal_dir = filename;
+  journal_dir = filename;
   return true;
 }
 
@@ -67,7 +67,7 @@
   for (const auto& update : expected) {
     Update result;
     bool end_of_journal = true;
-    TF_RETURN_IF_ERROR(reader.Read(&result, &end_of_journal));
+    TF_RETURN_IF_ERROR(reader.Read(result, end_of_journal));
     EXPECT_FALSE(end_of_journal);
     // We can't use the testing::EqualsProto matcher because it is not available
     // in OSS.
@@ -75,7 +75,7 @@
   }
   Update result;
   bool end_of_journal = false;
-  TF_RETURN_IF_ERROR(reader.Read(&result, &end_of_journal));
+  TF_RETURN_IF_ERROR(reader.Read(result, end_of_journal));
   EXPECT_TRUE(end_of_journal);
   return Status::OK();
 }
@@ -83,7 +83,7 @@
 
 TEST(Journal, RoundTripMultiple) {
   std::string journal_dir;
-  EXPECT_TRUE(NewJournalDir(&journal_dir));
+  EXPECT_TRUE(NewJournalDir(journal_dir));
   std::vector<Update> updates = {MakeCreateJobUpdate(),
                                  MakeRegisterDatasetUpdate(),
                                  MakeFinishTaskUpdate()};
@@ -97,7 +97,7 @@
 
 TEST(Journal, AppendExistingJournal) {
   std::string journal_dir;
-  EXPECT_TRUE(NewJournalDir(&journal_dir));
+  EXPECT_TRUE(NewJournalDir(journal_dir));
   std::vector<Update> updates = {MakeCreateJobUpdate(),
                                  MakeRegisterDatasetUpdate(),
                                  MakeFinishTaskUpdate()};
@@ -111,17 +111,17 @@
 
 TEST(Journal, MissingFile) {
   std::string journal_dir;
-  EXPECT_TRUE(NewJournalDir(&journal_dir));
+  EXPECT_TRUE(NewJournalDir(journal_dir));
   FileJournalReader reader(Env::Default(), journal_dir);
   Update result;
   bool end_of_journal = true;
-  Status s = reader.Read(&result, &end_of_journal);
+  Status s = reader.Read(result, end_of_journal);
   EXPECT_TRUE(errors::IsNotFound(s));
 }
 
 TEST(Journal, NonRecordData) {
   std::string journal_dir;
-  EXPECT_TRUE(NewJournalDir(&journal_dir));
+  EXPECT_TRUE(NewJournalDir(journal_dir));
 
   TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(journal_dir));
   {
@@ -134,14 +134,14 @@
   FileJournalReader reader(Env::Default(), journal_dir);
   Update result;
   bool end_of_journal = true;
-  Status s = reader.Read(&result, &end_of_journal);
+  Status s = reader.Read(result, end_of_journal);
   EXPECT_THAT(s.error_message(), HasSubstr("corrupted record"));
   EXPECT_EQ(s.code(), error::DATA_LOSS);
 }
 
 TEST(Journal, InvalidRecordData) {
   std::string journal_dir;
-  EXPECT_TRUE(NewJournalDir(&journal_dir));
+  EXPECT_TRUE(NewJournalDir(journal_dir));
 
   TF_ASSERT_OK(Env::Default()->RecursivelyCreateDir(journal_dir));
   {
@@ -155,7 +155,7 @@
   FileJournalReader reader(Env::Default(), journal_dir);
   Update result;
   bool end_of_journal = true;
-  Status s = reader.Read(&result, &end_of_journal);
+  Status s = reader.Read(result, end_of_journal);
   EXPECT_THAT(s.error_message(), HasSubstr("Failed to parse journal record"));
   EXPECT_EQ(s.code(), error::DATA_LOSS);
 }
diff --git a/tensorflow/core/data/service/local_credentials_factory.cc b/tensorflow/core/data/service/local_credentials_factory.cc
index 136bf49..b9426e7 100644
--- a/tensorflow/core/data/service/local_credentials_factory.cc
+++ b/tensorflow/core/data/service/local_credentials_factory.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/memory/memory.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
 
 namespace tensorflow {
@@ -23,14 +24,14 @@
   std::string Protocol() override { return "grpc+local"; }
 
   Status CreateServerCredentials(
-      std::shared_ptr<::grpc::ServerCredentials>* out) override {
-    *out = grpc::experimental::LocalServerCredentials(LOCAL_TCP);
+      std::shared_ptr<::grpc::ServerCredentials>& out) override {
+    out = grpc::experimental::LocalServerCredentials(LOCAL_TCP);
     return Status::OK();
   }
 
   Status CreateClientCredentials(
-      std::shared_ptr<::grpc::ChannelCredentials>* out) override {
-    *out = grpc::experimental::LocalCredentials(LOCAL_TCP);
+      std::shared_ptr<::grpc::ChannelCredentials>& out) override {
+    out = grpc::experimental::LocalCredentials(LOCAL_TCP);
     return Status::OK();
   }
 };
@@ -38,8 +39,7 @@
 class LocalCredentialsRegistrar {
  public:
   LocalCredentialsRegistrar() {
-    auto factory = new LocalCredentialsFactory();
-    CredentialsFactory::Register(factory);
+    CredentialsFactory::Register(absl::make_unique<LocalCredentialsFactory>());
   }
 };
 static LocalCredentialsRegistrar registrar;
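Registration now transfers ownership explicitly through `absl::make_unique` instead of leaking a raw `new`. Any other factory would register the same way; a sketch with a hypothetical `MyCredentialsFactory` subclass:

    class MyCredentialsRegistrar {
     public:
      MyCredentialsRegistrar() {
        CredentialsFactory::Register(absl::make_unique<MyCredentialsFactory>());
      }
    };
    static MyCredentialsRegistrar my_registrar;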
diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc
index 4ee186c..83a6e67 100644
--- a/tensorflow/core/data/service/server_lib.cc
+++ b/tensorflow/core/data/service/server_lib.cc
@@ -46,13 +46,13 @@
   ::grpc::ServerBuilder builder;
   std::shared_ptr<::grpc::ServerCredentials> credentials;
   TF_RETURN_IF_ERROR(
-      CredentialsFactory::CreateServerCredentials(protocol_, &credentials));
+      CredentialsFactory::CreateServerCredentials(protocol_, credentials));
   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
                            credentials, &bound_port_);
   builder.SetMaxReceiveMessageSize(-1);
 
-  AddDataServiceToBuilder(&builder);
-  AddProfilerServiceToBuilder(&builder);
+  AddDataServiceToBuilder(builder);
+  AddProfilerServiceToBuilder(builder);
   server_ = builder.BuildAndStart();
   if (!server_) {
     return errors::Internal("Could not start gRPC server");
@@ -81,9 +81,9 @@
 int GrpcDataServerBase::BoundPort() { return bound_port(); }
 
 void GrpcDataServerBase::AddProfilerServiceToBuilder(
-    ::grpc::ServerBuilder* builder) {
-  profiler_service_ = CreateProfilerService();
-  builder->RegisterService(profiler_service_.get());
+    ::grpc::ServerBuilder& builder) {
+  profiler_service_ = profiler::CreateProfilerService();
+  builder.RegisterService(profiler_service_.get());
 }
 
 DispatchGrpcDataServer::DispatchGrpcDataServer(
@@ -94,8 +94,8 @@
 DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
 
 void DispatchGrpcDataServer::AddDataServiceToBuilder(
-    ::grpc::ServerBuilder* builder) {
-  service_ = absl::make_unique<GrpcDispatcherImpl>(builder, config_).release();
+    ::grpc::ServerBuilder& builder) {
+  service_ = absl::make_unique<GrpcDispatcherImpl>(config_, builder).release();
 }
 
 Status DispatchGrpcDataServer::StartServiceInternal() {
@@ -122,8 +122,8 @@
 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
 
 void WorkerGrpcDataServer::AddDataServiceToBuilder(
-    ::grpc::ServerBuilder* builder) {
-  service_ = absl::make_unique<GrpcWorkerImpl>(builder, config_).release();
+    ::grpc::ServerBuilder& builder) {
+  service_ = absl::make_unique<GrpcWorkerImpl>(config_, builder).release();
 }
 
 Status WorkerGrpcDataServer::StartServiceInternal() {
@@ -139,14 +139,14 @@
 }
 
 Status NewDispatchServer(const experimental::DispatcherConfig& config,
-                         std::unique_ptr<DispatchGrpcDataServer>* out_server) {
-  *out_server = absl::make_unique<DispatchGrpcDataServer>(config);
+                         std::unique_ptr<DispatchGrpcDataServer>& out_server) {
+  out_server = absl::make_unique<DispatchGrpcDataServer>(config);
   return Status::OK();
 }
 
 Status NewWorkerServer(const experimental::WorkerConfig& config,
-                       std::unique_ptr<WorkerGrpcDataServer>* out_server) {
-  *out_server = absl::make_unique<WorkerGrpcDataServer>(config);
+                       std::unique_ptr<WorkerGrpcDataServer>& out_server) {
+  out_server = absl::make_unique<WorkerGrpcDataServer>(config);
   return Status::OK();
 }
 
diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h
index 0ddc806..c45ec14 100644
--- a/tensorflow/core/data/service/server_lib.h
+++ b/tensorflow/core/data/service/server_lib.h
@@ -53,8 +53,8 @@
   int BoundPort();
 
  protected:
-  virtual void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) = 0;
-  void AddProfilerServiceToBuilder(::grpc::ServerBuilder* builder);
+  virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0;
+  void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder);
   // Starts the service. This will be called after building the service, so
   // bound_port() will return the actual bound port.
   virtual Status StartServiceInternal() = 0;
@@ -84,7 +84,7 @@
   Status NumWorkers(int* num_workers);
 
  protected:
-  void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
+  void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
   Status StartServiceInternal() override;
 
  private:
@@ -99,7 +99,7 @@
   ~WorkerGrpcDataServer() override;
 
  protected:
-  void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
+  void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
   Status StartServiceInternal() override;
 
  private:
@@ -108,13 +108,13 @@
   GrpcWorkerImpl* service_;
 };
 
-// Creates a dispatch tf.data server and stores it in `*out_server`.
+// Creates a dispatch tf.data server and stores it in `out_server`.
 Status NewDispatchServer(const experimental::DispatcherConfig& config,
-                         std::unique_ptr<DispatchGrpcDataServer>* out_server);
+                         std::unique_ptr<DispatchGrpcDataServer>& out_server);
 
-// Creates a worker tf.data server and stores it in `*out_server`.
+// Creates a worker tf.data server and stores it in `out_server`.
 Status NewWorkerServer(const experimental::WorkerConfig& config,
-                       std::unique_ptr<WorkerGrpcDataServer>* out_server);
+                       std::unique_ptr<WorkerGrpcDataServer>& out_server);
 
 }  // namespace data
 }  // namespace tensorflow
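Call sites drop the address-of operator on the out parameter; a sketch of standing up a dispatch server with the new signature (the test_cluster.cc hunks below show the same pattern in context):

    Status StartDispatcher(const experimental::DispatcherConfig& config) {
      std::unique_ptr<DispatchGrpcDataServer> dispatcher;
      TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher));
      return dispatcher->Start();
    }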
diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc
index 8ae3f19..49f7eae 100644
--- a/tensorflow/core/data/service/test_cluster.cc
+++ b/tensorflow/core/data/service/test_cluster.cc
@@ -49,7 +49,7 @@
   experimental::DispatcherConfig config;
   config.set_port(0);
   config.set_protocol(kProtocol);
-  TF_RETURN_IF_ERROR(NewDispatchServer(config, &dispatcher_));
+  TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher_));
   TF_RETURN_IF_ERROR(dispatcher_->Start());
   dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
   workers_.reserve(num_workers_);
@@ -67,7 +67,7 @@
   config.set_protocol(kProtocol);
   config.set_dispatcher_address(dispatcher_address_);
   config.set_worker_address("localhost:%port%");
-  TF_RETURN_IF_ERROR(NewWorkerServer(config, &worker));
+  TF_RETURN_IF_ERROR(NewWorkerServer(config, worker));
   TF_RETURN_IF_ERROR(worker->Start());
   worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
   workers_.push_back(std::move(worker));
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index e23d4ab..cc61c48 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -249,7 +249,6 @@
 }
 
 Status DataServiceWorkerImpl::SendTaskUpdates() LOCKS_EXCLUDED(mu_) {
-  WorkerUpdateRequest req;
   std::vector<TaskProgress> task_progress;
   {
     mutex_lock l(mu_);
@@ -265,10 +264,10 @@
 
   TF_RETURN_IF_ERROR(dispatcher_->WorkerUpdate(worker_address_, task_progress));
   mutex_lock l(mu_);
-  for (const auto& update : req.updates()) {
+  for (const auto& update : task_progress) {
     pending_completed_tasks_.erase(update.task_id());
   }
-  VLOG(3) << "Sent " << req.updates().size() << " task updates ";
+  VLOG(3) << "Sent " << task_progress.size() << " task updates ";
   return Status::OK();
 }
 
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 505e0c3..94570c1 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -538,6 +538,7 @@
         "//tensorflow/core:lib_internal",  # protobuf::Any
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:worker_proto_cc",
+        "@com_google_absl//absl/memory",
     ],
 )
 
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index dbc9417..203c63c 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -155,9 +155,6 @@
         PopulateTensorFromExtra(extra, to_tensor);
       }
     }
-    if (!s.ok() && errors::IsFailedPrecondition(s)) {
-      dev_resolver_->ClearTask(peer_task);
-    }
 
     delete state;
     done(s);
@@ -185,6 +182,62 @@
                                           dev_attributes_callback);
 }
 
+void CollectiveRemoteAccessDistributed::CheckPeerHealth(
+    const string& peer_task, const StatusCallback& done) {
+  if (peer_task == task_name_) {
+    // Fast path if the peer is the worker itself.
+    done(Status::OK());
+    return;
+  }
+  // We send a GetStatus RPC with fail_fast=true to check the health of a peer
+  // task. If the RPC succeeds, we verify that the peer's device incarnations
+  // match our local records, if we have any. Note that
+  // DeviceResolverInterface always caches the device attributes.
+  WorkerInterface* wi = worker_cache_->GetOrCreateWorker(peer_task);
+  if (wi == nullptr) {
+    done(errors::InvalidArgument(peer_task,
+                                 " not found. It's probably invalid. The "
+                                 "valid form is /job:xxx/replica:0/task:N"));
+    return;
+  }
+  auto req = new GetStatusRequest();
+  auto resp = new GetStatusResponse();
+  // We're not using a cancellable call because GetStatusAsync doesn't support
+  // cancellation yet.
+  wi->GetStatusAsync(
+      req, resp, /*fail_fast*/ true,
+      [this, req, resp, wi, peer_task, done](Status s) {
+        std::vector<DeviceAttributes> cached_attrs;
+        if (s.ok()) {
+          s = dev_resolver_->GetTaskCached(peer_task, &cached_attrs);
+        }
+        if (s.ok()) {
+          absl::flat_hash_set<uint64> remote_incarnations;
+          for (const DeviceAttributes& da : resp->device_attributes()) {
+            remote_incarnations.insert(da.incarnation());
+          }
+          for (const DeviceAttributes& attr : cached_attrs) {
+            if (!remote_incarnations.contains(attr.incarnation())) {
+              s = errors::FailedPrecondition(
+                  attr.name(), " with incarnation ", attr.incarnation(),
+                  " is not available. This usually means ", peer_task,
+                  " has restarted");
+              break;
+            }
+          }
+        } else if (errors::IsNotFound(s)) {
+          // Skip validating device incarnations if we don't know what they
+          // should be; device attributes are only cached after the first
+          // collective.
+          s = Status::OK();
+        }
+        delete req;
+        delete resp;
+        worker_cache_->ReleaseWorker(peer_task, wi);
+        done(s);
+      });
+}
+
 void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
   CollectiveRemoteAccessLocal::StartAbort(s);
   cancel_mgr_.StartCancel();
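A sketch of driving the new `CheckPeerHealth` entry point, assuming `rma` is a configured `CollectiveRemoteAccessDistributed`; the task name and error handling are illustrative:

    Notification note;
    Status health;
    rma->CheckPeerHealth(
        "/job:worker/replica:0/task:1",
        [&health, &note](const Status& s) {
          health = s;
          note.Notify();
        });
    note.WaitForNotification();
    if (errors::IsFailedPrecondition(health)) {
      // Peer restarted; its device incarnations no longer match our cache.
    } else if (errors::IsUnavailable(health)) {
      // Peer is down or unreachable.
    }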
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index d6546e3..ed4d448 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -28,10 +28,11 @@
   CollectiveRemoteAccessDistributed(
       const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
       std::shared_ptr<UnboundedWorkQueue> work_queue,
-      WorkerCacheInterface* worker_cache, int64 step_id)
+      WorkerCacheInterface* worker_cache, int64 step_id, string task_name)
       : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
         worker_cache_(worker_cache),
-        work_queue_(std::move(work_queue)) {}
+        work_queue_(std::move(work_queue)),
+        task_name_(std::move(task_name)) {}
 
   ~CollectiveRemoteAccessDistributed() override {}
 
@@ -43,6 +44,9 @@
                     int dev_to_dev_stream_index,
                     const StatusCallback& done) override;
 
+  void CheckPeerHealth(const string& peer_task,
+                       const StatusCallback& done) override;
+
   void StartAbort(const Status& s) override;
 
  protected:
@@ -51,6 +55,7 @@
   // `CollectiveExecutorMgr`.
   std::shared_ptr<UnboundedWorkQueue> work_queue_;
   CancellationManager cancel_mgr_;
+  string task_name_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 2975442..b6975e4 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -63,11 +63,12 @@
 class FakeWorker : public TestWorkerInterface {
  public:
   FakeWorker(const string& name, DeviceMgr* dev_mgr,
-             DeviceResolverDistributed* dres)
+             DeviceResolverDistributed* dres, bool is_failed)
       : name_(name),
         device_mgr_(dev_mgr),
         device_resolver_(dres),
-        buf_rendezvous_(kStepId, dev_mgr) {}
+        buf_rendezvous_(kStepId, dev_mgr),
+        is_failed_(is_failed) {}
 
   // Direct access to a BufRendezvous that holds whatever the remote
   // worker is supposed to have.
@@ -76,6 +77,10 @@
   void GetStatusAsync(const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
+    if (is_failed_) {
+      done(errors::Unavailable("peer down"));
+      return;
+    }
     std::vector<DeviceAttributes> dev_attr;
     device_mgr_->ListDeviceAttributes(&dev_attr);
     for (const auto& da : dev_attr) {
@@ -86,6 +91,10 @@
 
   void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                     RecvBufResponse* response, StatusCallback done) override {
+    if (is_failed_) {
+      done(errors::Unavailable("peer down"));
+      return;
+    }
     opts->SetCancelCallback([this]() {
       // Within this test the call is satisfied by a process-local
       // BufRendezvous table. In real application the BufRendezvous
@@ -125,6 +134,7 @@
   DeviceMgr* device_mgr_;
   DeviceResolverDistributed* device_resolver_;
   BufRendezvous buf_rendezvous_;
+  bool is_failed_;
 };
 
 class FakeCache : public TestWorkerCache {
@@ -201,7 +211,7 @@
     // All tests simulate requests from worker 0 to worker 1.
     rma_.reset(new CollectiveRemoteAccessDistributed(
         device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
-        kStepId));
+        kStepId, "/job:worker/replica:0/task:0"));
 
     const int kNumElts = 8;
     expected_value_ = Tensor(DT_FLOAT, {kNumElts});
@@ -215,7 +225,7 @@
   }
 
   void DefineWorker(const string& worker_name, const string& device_type,
-                    int num_devices) {
+                    int num_devices, bool is_failed = false) {
     std::vector<std::unique_ptr<Device>> devices;
     for (int i = 0; i < num_devices; ++i) {
       devices.push_back(NewDevice(
@@ -232,19 +242,19 @@
     DeviceResolverDistributed* dev_res =
         new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
     dev_resolvers_[worker_name] = dev_res;
-    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
+    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res, is_failed);
     workers_.push_back(fw);
     wc_.AddWorker(worker_name, fw);
   }
 
   void RestartWorker(const string& worker_name, const string& device_type,
-                     int num_devices) {
+                     int num_devices, bool is_failed = false) {
     auto it = dev_resolvers_.find(worker_name);
     if (it != dev_resolvers_.end()) {
       delete it->second;
       dev_resolvers_.erase(it);
     }
-    DefineWorker(worker_name, device_type, num_devices);
+    DefineWorker(worker_name, device_type, num_devices, is_failed);
   }
 
   void ValidateResultTensor() {
@@ -401,7 +411,7 @@
   ValidateResultTensor();
 
   // Restart task 1 and check that recv from task 1 to task 0 fails.
-  RestartWorker("/job:worker/replica:0/task:1", "CPU", 1);
+  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
   Notification post_restart_note;
   rma_->RecvFromPeer(
       "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
@@ -417,5 +427,139 @@
   EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status));
 }
 
+TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
+  DeviceAttributes attr;
+  Status get_attr_status;
+  Notification get_attr_done;
+  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+  // worker.
+  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+      "/job:worker/replica:0/task:1/device:CPU:0",
+      "/job:worker/replica:0/task:1", &attr,
+      [&get_attr_status, &get_attr_done](const Status& s) {
+        get_attr_status = s;
+        get_attr_done.Notify();
+      });
+  get_attr_done.WaitForNotification();
+  TF_ASSERT_OK(get_attr_status);
+
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  TF_EXPECT_OK(check_health_status);
+}
+
+TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  EXPECT_TRUE(check_health_status.ok());
+}
+
+TEST_F(CollRMADistTest, CheckHealthRestarted) {
+  DeviceAttributes attr;
+  Status get_attr_status;
+  Notification get_attr_done;
+  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+  // worker.
+  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+      "/job:worker/replica:0/task:1/device:CPU:0",
+      "/job:worker/replica:0/task:1", &attr,
+      [&get_attr_status, &get_attr_done](const Status& s) {
+        get_attr_status = s;
+        get_attr_done.Notify();
+      });
+  get_attr_done.WaitForNotification();
+  TF_ASSERT_OK(get_attr_status);
+
+  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
+
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
+}
+
+TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
+  DeviceAttributes attr;
+  Status get_attr_status;
+  Notification get_attr_done;
+  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+  // worker.
+  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+      "/job:worker/replica:0/task:1/device:CPU:0",
+      "/job:worker/replica:0/task:1", &attr,
+      [&get_attr_status, &get_attr_done](const Status& s) {
+        get_attr_status = s;
+        get_attr_done.Notify();
+      });
+  get_attr_done.WaitForNotification();
+  TF_ASSERT_OK(get_attr_status);
+
+  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
+                /*is_failed*/ true);
+
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  EXPECT_TRUE(errors::IsUnavailable(check_health_status));
+}
+
+TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
+  RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
+
+  DeviceAttributes attr;
+  Status get_attr_status;
+  Notification get_attr_done;
+  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
+  // worker.
+  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
+      "/job:worker/replica:0/task:1/device:GPU:0",
+      "/job:worker/replica:0/task:1", &attr,
+      [&get_attr_status, &get_attr_done](const Status& s) {
+        get_attr_status = s;
+        get_attr_done.Notify();
+      });
+  get_attr_done.WaitForNotification();
+  TF_ASSERT_OK(get_attr_status);
+
+  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
+
+  Status check_health_status;
+  Notification check_health_done;
+  rma_->CheckPeerHealth(
+      "/job:worker/replica:0/task:1",
+      [&check_health_status, &check_health_done](const Status s) {
+        check_health_status = s;
+        check_health_done.Notify();
+      });
+  check_health_done.WaitForNotification();
+  EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
index 927925c..8b459c26 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
@@ -113,25 +113,20 @@
       });
 }
 
-void DeviceResolverDistributed::ClearTask(const string& task) {
+Status DeviceResolverDistributed::GetTaskCached(
+    const string& task, std::vector<DeviceAttributes>* attributes) {
   mutex_lock l(mu_);
-  // First find all the keys belonging to the task.
-  std::unordered_set<string> task_keys;
+  attributes->clear();
   for (const auto& it : attr_table_) {
     const string& device_name = it.first;
     if (DeviceNameUtils::IsSameAddressSpace(task, device_name)) {
-      task_keys.insert(device_name);
+      attributes->push_back(it.second);
     }
   }
-  // Then delete them.
-  for (const string& key : task_keys) {
-    attr_table_.erase(key);
+  if (attributes->empty()) {
+    return errors::NotFound(task, " not found in the cache");
   }
-}
-
-void DeviceResolverDistributed::ClearCache() {
-  mutex_lock l(mu_);
-  attr_table_.clear();
+  return Status::OK();
 }
 
 }  // namespace tensorflow
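`GetTaskCached` replaces the two cache-clearing methods with a read-only snapshot of what the resolver has seen for a task; a sketch of a hypothetical caller (task name illustrative):

    std::vector<DeviceAttributes> cached_attrs;
    Status s = resolver->GetTaskCached("/job:worker/replica:0/task:1",
                                       &cached_attrs);
    if (errors::IsNotFound(s)) {
      // Nothing cached yet; the first collective for this task hasn't run.
    }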
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.h b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
index 93d51a5..a041557 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed.h
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
@@ -43,9 +43,8 @@
                                 DeviceAttributes* attributes,
                                 const StatusCallback& done) override;
 
-  void ClearTask(const string& task) override;
-
-  void ClearCache() override;
+  Status GetTaskCached(const string& task,
+                       std::vector<DeviceAttributes>* attributes) override;
 
  protected:
   // Loads attr_table_ with device attributes retrieved from remote task.
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
index 3d7523f..25f3665 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
@@ -178,55 +178,6 @@
     wc_.AddWorker(worker_name, fw);
   }
 
-  void RestartWorker(const string& worker_name, const string& device_type,
-                     int num_devices, uint64 device_incarnation_base) {
-    for (auto it : resolvers_) {
-      it.second->ClearCache();
-    }
-    // `DefineWorker` creates a device resolver and a worker and adds them to
-    // resolvers_ and workers_.  Recreating the worker would overwrite these map
-    // entries.  We destroy the old device resolver here; all other objects are
-    // cleaned up in the destructor.
-    delete resolvers_[worker_name];
-    DefineWorker(worker_name, device_type, num_devices,
-                 device_incarnation_base);
-  }
-
-  void ResolveIncarnationsAndValidate(
-      const int num_workers, const int num_devices, const string& worker_prefix,
-      const string& device_type,
-      const std::vector<std::vector<uint64>>& expected_incarnations) {
-    for (int w = 0; w < num_workers; ++w) {
-      const string worker_name = absl::StrCat(worker_prefix, w);
-      auto* device_resolver = resolvers_[worker_name];
-      const string device_prefix =
-          absl::StrCat(worker_name, "/device:", device_type, ":");
-      for (int peer_w = 0; peer_w < num_workers; ++peer_w) {
-        const string peer_worker_name = absl::StrCat(worker_prefix, peer_w);
-        for (int d = 0; d < num_devices; ++d) {
-          const string device_name =
-              absl::StrCat(peer_worker_name, "/device:", device_type, ":", d);
-          DeviceNameUtils::ParsedName parsed;
-          ASSERT_TRUE(DeviceNameUtils::ParseFullName(device_name, &parsed));
-          // NOLINT prevents linter from suggesting absl::Notification as a
-          // replacement, which is not available in OSS.
-          Notification note;  // NOLINT
-          Status status;
-          DeviceAttributes attributes;
-          device_resolver->GetDeviceAttributesAsync(
-              device_name, peer_worker_name, &attributes,
-              [&note, &status](const Status& s) {
-                status = s;
-                note.Notify();
-              });
-          note.WaitForNotification();
-          TF_EXPECT_OK(status);
-          EXPECT_EQ(attributes.incarnation(), expected_incarnations[peer_w][d]);
-        }
-      }
-    }
-  }
-
   FakeCache wc_;
   std::vector<DeviceMgr*> device_mgrs_;
   std::unordered_map<string, TestableDeviceResolverDistributed*> resolvers_;
@@ -259,52 +210,6 @@
       }
     }
   }
-  // Clear just task 0 from all.
-  const string w0_name = "/job:worker/replica:0/task:0";
-  for (auto it : resolvers_) {
-    if (it.first == w0_name) continue;
-    TestableDeviceResolverDistributed* dres = it.second;
-    EXPECT_EQ(8, it.second->attr_table().size());
-    dres->ClearTask("/job:worker/replica:0/task:0");
-    EXPECT_EQ(4, it.second->attr_table().size());
-  }
-}
-
-TEST_F(DeviceResDistTest, DeviceIncarnationChangesOnFailure) {
-  constexpr int num_workers = 3;
-  constexpr int num_devices = 4;
-  constexpr int failing_worker_index = 1;
-  const string device_type = "CPU";
-  constexpr uint64 device_incarnation_base = 100;
-  DefineWorkers(num_workers, num_devices, device_type, device_incarnation_base);
-  const string worker_prefix = "/job:worker/replica:0/task:";
-  const string failing_worker =
-      absl::StrCat(worker_prefix, failing_worker_index);
-
-  // Check device incarnations match expected.
-  std::vector<std::vector<uint64>> expected_incarnations(num_workers);
-  for (int w = 0; w < num_workers; ++w) {
-    expected_incarnations[w].resize(num_devices);
-    for (int d = 0; d < num_devices; ++d) {
-      expected_incarnations[w][d] =
-          w * num_devices + d + device_incarnation_base;
-    }
-  }
-  ResolveIncarnationsAndValidate(num_workers, num_devices, worker_prefix,
-                                 device_type, expected_incarnations);
-
-  // Restart worker `failing_worker`.
-  constexpr uint64 restart_incarnation_base = 200;
-  RestartWorker(failing_worker, device_type, num_devices,
-                restart_incarnation_base);
-  for (int d = 0; d < num_devices; ++d) {
-    expected_incarnations[failing_worker_index][d] =
-        d + restart_incarnation_base;
-  }
-
-  // Check incarnations have changed for `failing worker`.
-  ResolveIncarnationsAndValidate(num_workers, num_devices, worker_prefix,
-                                 device_type, expected_incarnations);
 }
 
 }  // namespace
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index c3ed312..1ff1032 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -278,9 +278,8 @@
   opts.config = request->server_def().default_session_config();
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
-      request->lazy_copy_remote_function_inputs(), device_mgr, false, r,
-      GetDefaultCustomKernelCreator(), worker_session->cluster_flr());
+      request->async(), request->lazy_copy_remote_function_inputs(), device_mgr,
+      false, r, GetDefaultCustomKernelCreator(), worker_session->cluster_flr());
   // Ownership will be transferred to the ServerContext, or else in an error
   // case ctx will be deleted by this unref.
   core::ScopedUnref unref_ctx(ctx);
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 700cea1..6c9149c 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -1218,7 +1218,7 @@
   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false,
+      /*async=*/false,
       /*lazy_copy_function_remote_inputs=*/false, device_mgr_.get(), false,
       rendezvous, GetDefaultCustomKernelCreator());
   const uint64 context_id = random::New64();
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
index 9544906..901e757 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc
@@ -54,7 +54,7 @@
     ctx_ = new tensorflow::EagerContext(
         SessionOptions(),
         tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /*async=*/false,
+        /*async=*/false,
         /*lazy_copy_function_remote_inputs=*/false, device_mgr.release(), true,
         rendezvous, GetDefaultCustomKernelCreator(), nullptr);
   }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index 6e70617..d529abe 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -35,11 +35,10 @@
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/protobuf/transport_options.pb.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/util/env_var.h"
 
 namespace tensorflow {
 
-const int kMaxWorkerRpcRetries = 10;
-
 class GrpcRemoteWorker : public WorkerInterface {
  public:
   explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
@@ -274,7 +273,7 @@
                     bool fail_fast = true) {
     new RPCState<protobuf::Message>(
         &stub_, cq_, method, *request, response, std::move(done), call_opts,
-        callback_threadpool_, /*max_retries=*/0, fail_fast, &target_);
+        callback_threadpool_, MaxRetries(), fail_fast, &target_);
   }
 
   void IssueRequest(const protobuf::Message* request, TensorResponse* response,
@@ -282,7 +281,7 @@
                     CallOptions* call_opts = nullptr) {
     new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
                                  std::move(done), call_opts,
-                                 callback_threadpool_, /*max_retries=*/0,
+                                 callback_threadpool_, MaxRetries(),
                                  /*fail_fast=*/true, &target_);
   }
 
@@ -299,6 +298,14 @@
   // Helper function for initializing the RpcMethod objects below.
   const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
 
+  // Helper function for configuring the maximum number of gRPC retries.
+  // Defaults to 0 (no retries).
+  const int64 MaxRetries() {
+    int64 max_retries = -1;
+    TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
+    return max_retries;
+  }
+
   SharedGrpcChannelPtr channel_;
   ::grpc::GenericStub stub_;
   ::grpc::CompletionQueue* cq_;
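[Editor's note] The hunk above replaces the hard-coded retry constant with a per-request read of the `GRPC_MAX_RETRIES` environment variable, defaulting to 0. A minimal standalone sketch of the same pattern, assuming plain `std::getenv` rather than TensorFlow's `ReadInt64FromEnvVar` helper:

```cpp
#include <cstdint>
#include <cstdlib>
#include <string>

// Hypothetical standalone equivalent of reading GRPC_MAX_RETRIES with a
// default of 0 (no retries). TensorFlow's ReadInt64FromEnvVar also reports
// parse failures via Status; here we simply fall back to the default.
int64_t MaxRetriesFromEnv() {
  const char* val = std::getenv("GRPC_MAX_RETRIES");
  if (val == nullptr) return 0;  // Unset: keep the old behavior (no retries).
  try {
    return std::stoll(val);
  } catch (...) {
    return 0;  // Unparsable value: fall back to the default.
  }
}
```

Note that the diff reads the variable on every issued RPC; a caller issuing many RPCs could cache the value once at construction time.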
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 83e0725..71be10f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -249,7 +249,7 @@
                         .release();
   eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
 
-  profiler_service_ = CreateProfilerService();
+  profiler_service_ = profiler::CreateProfilerService();
   builder.RegisterService(profiler_service_.get());
 
   // extra service:
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
index 4fbc4bb..62a67b5 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -47,8 +47,9 @@
 
 CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
   CollectiveRemoteAccessDistributed* rma =
-      new CollectiveRemoteAccessDistributed(
-          dev_mgr_, dev_resolver_.get(), work_queue_, worker_cache_, step_id);
+      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+                                            work_queue_, worker_cache_, step_id,
+                                            task_name_);
   return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                     &gpu_ring_order_, work_queue_);
 }
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index da606bb..651b487 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -1097,9 +1097,9 @@
     ],
 )
 
-filegroup(
+cc_library(
     name = "pywrap_required_hdrs",
-    srcs = [
+    textual_hdrs = [
         "op_gen_lib.h",
         "rendezvous.h",
     ],
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 05eefed..72e0b3d 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -161,18 +161,15 @@
       std::vector<DeviceAttributes>* attributes,
       const StatusCallback& done) = 0;
 
-  // Populate *attributes with the DeviceAttributes of the specified
-  // device.
+  // Populates *attributes with the DeviceAttributes of the specified device.
   virtual void GetDeviceAttributesAsync(const string& device,
                                         const string& task,
                                         DeviceAttributes* attributes,
                                         const StatusCallback& done) = 0;
 
-  // Clear the cache of device data belonging to the specified task.
-  virtual void ClearTask(const string& task) = 0;
-
-  // Clear the cache of all device data.
-  virtual void ClearCache() = 0;
+  // Returns the cached device attributes of a task.
+  virtual Status GetTaskCached(const string& task,
+                               std::vector<DeviceAttributes>* attributes) = 0;
 };
 
 // Interface that provides resolution of shared CollectiveParams fields.
@@ -279,6 +276,12 @@
                           const DeviceLocality& client_locality,
                           const StatusCallback& done) = 0;
 
+  // Checks the health of a collective peer. It probes the peer to see if it
+  // is alive. Note that if a peer has restarted, it is considered a different
+  // peer, so CheckPeerHealth fails for it.
+  virtual void CheckPeerHealth(const string& peer_task,
+                               const StatusCallback& done) = 0;
+
   virtual BufRendezvous* buf_rendezvous() = 0;
 
   virtual void StartAbort(const Status& s) = 0;
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 35186f9..8c35b19 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -1188,20 +1188,6 @@
       registrar__body__##ctr##__object(op_name)
 
 }  // namespace data
-
-// TODO(b/114112161): Remove these aliases when all users have moved over to the
-// `tensorflow::data` namespace.
-using data::DatasetBase;
-using data::DatasetContext;
-using data::DatasetIterator;
-using data::DatasetOpKernel;
-using data::IteratorBase;
-using data::IteratorContext;
-using data::IteratorStateReader;
-using data::IteratorStateWriter;
-using data::SerializationContext;
-using data::UnaryDatasetOpKernel;
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index ebf06c7..564290b 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1454,6 +1454,12 @@
   return Status::OK();
 }
 
+void FunctionLibraryDefinition::Clear() {
+  mutex_lock l(mu_);
+  function_defs_.clear();
+  func_grad_.clear();
+}
+
 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
   const auto& i = func_grad_.find(func);
   if (i == func_grad_.end()) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 3c7c09e..3c04816 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -403,6 +403,9 @@
   // are no longer in use.
   Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);
 
+  // Removes all the functions and gradient functions.
+  void Clear() TF_LOCKS_EXCLUDED(mu_);
+
   // Adds the functions and gradients in 'other' to this function library.
   // Duplicate functions and gradients are ignored.
   // This operation is atomic.
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
index a62acfe..38ab8be 100644
--- a/tensorflow/core/framework/function_test.cc
+++ b/tensorflow/core/framework/function_test.cc
@@ -1068,6 +1068,16 @@
   EXPECT_FALSE(lib_def.Contains("XTimesTwo"));
 }
 
+TEST(FunctionLibraryDefinitionTest, Clear) {
+  FunctionLibraryDefinition lib_def(OpRegistry::Global(), {});
+  TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XTimesTwo()));
+  TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XAddX()));
+
+  lib_def.Clear();
+  EXPECT_FALSE(lib_def.Contains("XTimesTwo"));
+  EXPECT_FALSE(lib_def.Contains("XAddX"));
+}
+
 TEST(FunctionLibraryDefinitionTest, AddLibrary) {
   // Create lib def with single function
   FunctionDefLibrary proto;
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index bfa6e31..826c2ed 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -77,7 +77,14 @@
   Parameter(const string& name, std::shared_ptr<SharedState> state, double min,
             double max)
       : name(name),
-        value(state->value),
+        // Sometimes non-autotuned nodes (with `autotune_=false`) may contain
+        // parameters (for example, inputs of a parallel interleave dataset
+        // that are not in the current cycle). To avoid unrealistic values
+        // (say `buffer_size=-1` or `parallelism=-1`) in the optimization
+        // computation, if the state value is `kAutotune=-1` (which merely
+        // indicates that the `SharedState` is tunable), we initialize the
+        // parameter value to the parameter's minimum value.
+        value(state->value == kAutotune ? min : state->value),
         min(min),
         max(max),
         state(std::move(state)) {}
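[Editor's note] The guard above replaces the `kAutotune` sentinel with the parameter's lower bound before any optimization math runs. A reduced sketch of the idea, with types simplified and `kAutotune` assumed to be -1 as the comment states:

```cpp
#include <memory>

constexpr double kAutotune = -1.0;  // Sentinel: "value chosen by autotuning".

struct SharedState {
  double value = kAutotune;
};

struct Parameter {
  double value;
  double min, max;
  std::shared_ptr<SharedState> state;

  Parameter(std::shared_ptr<SharedState> s, double min, double max)
      // Clamp the sentinel to the lower bound so the optimizer never sees an
      // impossible setting such as parallelism = -1.
      : value(s->value == kAutotune ? min : s->value),
        min(min),
        max(max),
        state(std::move(s)) {}
};
```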
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index adc52d9..94b98d5 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -313,6 +313,7 @@
 #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
 #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
 #define REGISTER_OP_UNIQ(ctr, name)                                          \
+  TF_ATTRIBUTE_ANNOTATE("tf:op")                                             \
   static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr    \
       TF_ATTRIBUTE_UNUSED =                                                  \
           ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
@@ -326,6 +327,8 @@
 #define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \
   REGISTER_SYSTEM_OP_UNIQ(ctr, name)
 #define REGISTER_SYSTEM_OP_UNIQ(ctr, name)                                \
+  TF_ATTRIBUTE_ANNOTATE("tf:op")                                          \
+  TF_ATTRIBUTE_ANNOTATE("tf:op:system")                                   \
   static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
       TF_ATTRIBUTE_UNUSED =                                               \
           ::tensorflow::register_op::OpDefBuilderWrapper<true>(name)
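[Editor's note] `TF_ATTRIBUTE_ANNOTATE` attaches a compiler annotation to the registration symbol so build tooling can enumerate registered ops. A minimal sketch of how such a macro is typically defined and used; this is an illustration, not necessarily the exact TensorFlow definition:

```cpp
// On Clang, annotations attach to symbols and survive into the IR, where
// tooling can enumerate them (e.g. to list all registered ops or kernels).
#if defined(__clang__)
#define MY_ANNOTATE(str) __attribute__((annotate(str)))
#else
#define MY_ANNOTATE(str)  // No-op on compilers without the attribute.
#endif

// Hypothetical usage mirroring the registration macros above.
MY_ANNOTATE("tf:op")
static int example_registration_symbol = 0;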
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 3bfceda..0116a1f 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1457,6 +1457,7 @@
 #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
   constexpr bool should_register_##ctr##__flag =                      \
       SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
+  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                                  \
   static ::tensorflow::kernel_factory::OpKernelRegistrar              \
       registrar__body__##ctr##__object(                               \
           should_register_##ctr##__flag                               \
@@ -1479,6 +1480,8 @@
   REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
 
 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)    \
+  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                                     \
+  TF_ATTRIBUTE_ANNOTATE("tf:kernel:system")                              \
   static ::tensorflow::kernel_factory::OpKernelRegistrar                 \
       registrar__body__##ctr##__object(                                  \
           ::tensorflow::register_kernel::system::kernel_builder.Build(), \
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h
index dbe1030..70253f4 100644
--- a/tensorflow/core/framework/tensor_shape.h
+++ b/tensorflow/core/framework/tensor_shape.h
@@ -332,6 +332,11 @@
   friend class Tensor;
 };
 
+/// Outputs a `TensorShape` to a `std::ostream`.
+inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) {
+  return os << ts.DebugString();
+}
+
 /// Represents the value of one dimension in a TensorShape.
 struct TensorShapeDim {
   explicit TensorShapeDim(int64 s) : size(s) {}
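[Editor's note] With the streaming operator added above, shapes can be logged directly instead of calling `DebugString()` by hand. A small usage sketch, assuming a shape constructed from an initializer list:

```cpp
#include <iostream>

#include "tensorflow/core/framework/tensor_shape.h"

int main() {
  tensorflow::TensorShape shape({2, 3, 5});
  // Streams the DebugString form of the shape, e.g. "[2,3,5]".
  std::cout << "shape = " << shape << std::endl;
  return 0;
}
```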
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 3c4c186..0c57362 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -113,6 +113,12 @@
 // through template parameter.
 static const char* const kMklEagerOpPrefix = "_MklEager";
 
+// Prefix that we add to a TF op name to construct an MKL op that does not
+// depend on layout propagation. It is used in both eager and graph modes
+// unless there is a reason to have an additional op name with the
+// _MklEager prefix.
+static const char* const kMklNativeOpPrefix = "_MklNative";
+
 // Get the name of Mkl op from original TensorFlow op
 // We prefix 'Mkl' to the original op to get Mkl op.
 inline string GetMklOpName(const string& name) {
@@ -125,6 +131,12 @@
   return string(kMklEagerOpPrefix) + name;
 }
 
+// Get the name of the MKL native op (one that does not depend on layout
+// propagation) from the original TensorFlow op.
+inline string GetMklNativeOpName(const string& name) {
+  return string(kMklNativeOpPrefix) + name;
+}
+
 #ifdef ENABLE_INTEL_MKL_BFLOAT16
 static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
   return port::TestCPUFeature(port::CPUFeature::AVX512F);
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index a52160b..647a7be 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -65,8 +65,10 @@
 constexpr char kSend[] = "_Send";
 constexpr char kBatchMatMul[] = "BatchMatMul";
 constexpr char kBatchMatMulV2[] = "BatchMatMulV2";
+constexpr char kOneHot[] = "OneHot";
 constexpr char kPack[] = "Pack";
 constexpr char kRank[] = "Rank";
+constexpr char kRange[] = "Range";
 constexpr char kShape[] = "Shape";
 constexpr char kShapeN[] = "ShapeN";
 constexpr char kSize[] = "Size";
@@ -86,6 +88,7 @@
 constexpr char kStridedSlice[] = "StridedSlice";
 constexpr char kSpaceToDepth[] = "SpaceToDepth";
 constexpr char kTranspose[] = "Transpose";
+constexpr char kTile[] = "Tile";
 constexpr char kMaxPool[] = "MaxPool";
 constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
 constexpr char kAvgPool[] = "AvgPool";
@@ -471,8 +474,12 @@
                             wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
   device_cost_impl_.emplace(kFill,
                             wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
+  device_cost_impl_.emplace(kOneHot,
+                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
   device_cost_impl_.emplace(kPack,
                             wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
+  device_cost_impl_.emplace(kRange,
+                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
   device_cost_impl_.emplace(kSpaceToDepth,
                             wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
   device_cost_impl_.emplace(kSplit,
@@ -481,6 +488,8 @@
                             wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
   device_cost_impl_.emplace(kTranspose,
                             wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
+  device_cost_impl_.emplace(kTile,
+                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
   device_cost_impl_.emplace(kUnpack,
                             wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
 
@@ -1533,8 +1542,29 @@
   return total_output_size;
 }
 
+bool HasZeroDim(const OpInfo& op_info) {
+  for (int i = 0; i < op_info.inputs_size(); ++i) {
+    const auto& input = op_info.inputs(i);
+    for (int j = 0; j < input.shape().dim_size(); ++j) {
+      const auto& dim = input.shape().dim(j);
+      if (dim.size() == 0) {
+        VLOG(1) << "Convolution config has zero dim "
+                << op_info.ShortDebugString();
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 Costs OpLevelCostEstimator::PredictConv2D(const OpContext& op_context) const {
   const auto& op_info = op_context.op_info;
+  if (HasZeroDim(op_info)) {
+    Costs costs = Costs::ZeroCosts();
+    costs.inaccurate = true;
+    costs.num_ops_with_unknown_shapes = 1;
+    return costs;
+  }
   bool found_unknown_shapes = false;
   auto costs = PredictOpCountBasedCost(
       CountConv2DOperations(op_info, &found_unknown_shapes), op_info);
@@ -1546,6 +1576,12 @@
 Costs OpLevelCostEstimator::PredictConv2DBackpropInput(
     const OpContext& op_context) const {
   const auto& op_info = op_context.op_info;
+  if (HasZeroDim(op_info)) {
+    Costs costs = Costs::ZeroCosts();
+    costs.inaccurate = true;
+      costs.num_ops_with_unknown_shapes = 1;
+    return costs;
+  }
   bool found_unknown_shapes = false;
   auto costs =
       PredictOpCountBasedCost(CountConv2DBackpropInputOperations(
@@ -1559,6 +1595,12 @@
 Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
     const OpContext& op_context) const {
   const auto& op_info = op_context.op_info;
+  if (HasZeroDim(op_info)) {
+    Costs costs = Costs::ZeroCosts();
+    costs.inaccurate = true;
+      costs.num_ops_with_unknown_shapes = 1;
+    return costs;
+  }
   bool found_unknown_shapes = false;
   auto costs =
       PredictOpCountBasedCost(CountConv2DBackpropFilterOperations(
@@ -2326,7 +2368,7 @@
       CalculateTensorSize(op_context.op_info.inputs(0), &found_unknown_shapes);
   const int64 output_size =
       CalculateTensorSize(op_context.op_info.outputs(0), &found_unknown_shapes);
-  const int output_elements = CalculateTensorElementCount(
+  const int64 output_elements = CalculateTensorElementCount(
       op_context.op_info.outputs(0), &found_unknown_shapes);
 
   const auto half_pixel_centers =
@@ -2340,7 +2382,7 @@
   }
 
   // Compose cost of bilinear interpolation.
-  auto ops = 0;
+  int64 ops = 0;
 
 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
   const auto sub_cost_float = EIGEN_COST(scalar_difference_op<float>);
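[Editor's note] The switch from `auto ops = 0` to `int64 ops = 0` above matters because `auto` deduces `int` from the literal `0`, and the interpolation op count for large images overflows 32 bits. A minimal illustration of the failure mode, using hypothetical numbers in the same range as the test added below:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t dim = 40000;  // Large output image dimension.
  auto ops_bad = 0;           // Deduces int: 32 bits on common platforms.
  int64_t ops_good = 0;

  // dim * dim * 3 is about 4.8e9, which exceeds INT32_MAX (~2.1e9).
  ops_good += dim * dim * 3;
  ops_bad += static_cast<int>(dim * dim * 3);  // Truncated on conversion.

  std::cout << ops_good << " vs " << ops_bad << std::endl;
  return 0;
}
```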
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 0b62251..41e8ee8 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -793,6 +793,39 @@
   EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
 }
 
+TEST_F(OpLevelCostEstimatorTest, InvalidConv2DConfig) {
+  // Convolution ops.
+  const std::vector<std::string> conv_ops = {
+      "Conv2D",
+      "Conv2DBackpropFilter",
+      "Conv2DBackpropInput",
+      "DepthwiseConv2dNative",
+      "DepthwiseConv2dNativeBackpropFilter",
+      "DepthwiseConv2dNativeBackpropInput",
+  };
+  // A valid Conv2D config.
+  const std::vector<int> valid_conv_config = {16, 19, 19, 48, 48, 5, 5, 256};
+  for (const auto& op : conv_ops) {
+    // Test with setting one value in conv config to zero.
+    // PredictCosts() should return zero costs.
+    for (int i = 0; i < valid_conv_config.size(); ++i) {
+      std::vector<int> conv_config(valid_conv_config);
+      conv_config[i] = 0;
+      auto op_context = DescribeConvolution(
+          conv_config[0], conv_config[1], conv_config[2], conv_config[3],
+          conv_config[4], conv_config[5], conv_config[6], conv_config[7]);
+      op_context.op_info.set_op(op);
+      auto cost = PredictCosts(op_context);
+      EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+      EXPECT_EQ(Costs::Duration(0), cost.compute_time);
+      EXPECT_EQ(Costs::Duration(0), cost.execution_time);
+      EXPECT_EQ(1, cost.num_ops_total);
+      EXPECT_TRUE(cost.inaccurate);
+      EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
+    }
+  }
+}
+
 TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) {
   auto cost =
       PredictCosts(DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3));
@@ -1951,10 +1984,11 @@
   std::vector<std::string> reshape_ops = {
       "ConcatV2",     "DataFormatVecPermute",
       "DepthToSpace", "ExpandDims",
-      "Fill",         "Pack",
+      "Fill",         "OneHot",
+      "Pack",         "Range",
       "SpaceToDepth", "Split",
       "Squeeze",      "Transpose",
-      "Unpack"};
+      "Tile",         "Unpack"};
 
   const int kTensorSize = 1000;
   for (auto reshape_op : reshape_ops) {
@@ -2068,6 +2102,41 @@
     EXPECT_FALSE(cost.inaccurate);
     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
   }
+
+  {
+    // Cost with very large tensor.
+    op_context.op_info.clear_outputs();
+    // Number of elements in tensor exceeds 2^32.
+    constexpr int64 kLargeOutputImageDim = 40000;
+    DescribeTensor4D(1, kLargeOutputImageDim, kLargeOutputImageDim,
+                     kChannelSize, op_context.op_info.add_outputs());
+    const int64 kInterpWeightCost = 12;
+    // Using half_pixel_centers.
+    AttrValue half_pixel_centers;
+    half_pixel_centers.set_b(true);
+    (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
+        half_pixel_centers;
+
+    const int64 num_ops =
+        kInterpWeightCost * (kLargeOutputImageDim * 2) +
+        kComputeLerpCost *
+            (kLargeOutputImageDim * kLargeOutputImageDim * kChannelSize);
+    const int64 expected_compute_time = std::ceil(
+        num_ops /
+        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
+
+    const int64 expected_memory_time =
+        (kImageDim * kImageDim + kLargeOutputImageDim * kLargeOutputImageDim) *
+        4;
+
+    const auto cost = PredictCosts(op_context);
+    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
+    EXPECT_EQ(cost.memory_time, Costs::Duration(expected_memory_time));
+    EXPECT_EQ(cost.execution_time,
+              Costs::Duration(expected_memory_time + expected_compute_time));
+    EXPECT_FALSE(cost.inaccurate);
+    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
+  }
 }
 
 }  // end namespace grappler
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 6b961c1..491b6bb 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -334,6 +334,8 @@
 
 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
 
+bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }
+
 bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
 
 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 1bf2672..871353e 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -99,6 +99,7 @@
 bool IsImag(const NodeDef& node);
 bool IsImmutableConst(const NodeDef& node);
 bool IsInvGrad(const NodeDef& node);
+bool IsLeakyRelu(const NodeDef& node);
 bool IsLess(const NodeDef& node);
 bool IsLessEqual(const NodeDef& node);
 bool IsLog(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 9d2925e..d187046 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -880,6 +880,7 @@
     deps = [
         ":remapper",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index d595d2b..e44bed4 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1338,7 +1338,9 @@
     TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
     const TensorProto& raw_val = input_node->attr().at("value").tensor();
     Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
-    CHECK(value->FromProto(raw_val));
+    CHECK(value->FromProto(raw_val))
+        << "Unable to make Tensor from proto for " << node.name()
+        << " with shape " << raw_val.tensor_shape().DebugString();
     inputs.emplace_back(value);
     total_inputs_size += value->TotalBytes();
   }
diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
index b772192..4d324ec 100644
--- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc
+++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
@@ -561,7 +561,8 @@
 Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
                      AutoShardPolicy policy, int64 num_replicas,
                      GraphDef* output) {
-  if (policy == AutoShardPolicy::OFF || (num_workers == 1 && index == 0)) {
+  if (policy == AutoShardPolicy::OFF ||
+      (policy == AutoShardPolicy::FILE && num_workers == 1 && index == 0)) {
     return Status::OK();
   }
 
diff --git a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc
index ee8f9e8..52f1ba5 100644
--- a/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc
+++ b/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.cc
@@ -89,8 +89,15 @@
   // `max_intra_op_parallelism` input
   *insert_node.mutable_input()->Add() = max_parallelism_value->name();
 
-  for (const auto& attr_name : {"output_types", "output_shapes"}) {
-    graph_utils::CopyAttribute(attr_name, *last_node, &insert_node);
+  // Set the `output_shapes` and `output_types` attributes by copying them
+  // from the input node. If either attr is missing from the input node, we
+  // skip the rewrite.
+  for (auto attr : {"output_shapes", "output_types"}) {
+    if (last_node->attr().find(attr) != last_node->attr().end()) {
+      graph_utils::CopyAttribute(attr, *last_node, &insert_node);
+    } else {
+      return Status::OK();
+    }
   }
 
   auto* added_node = graph.AddNode(std::move(insert_node));
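[Editor's note] The loop above only copies `output_shapes` and `output_types` when the input node actually carries them, and silently skips the rewrite otherwise. A reduced sketch of that guard pattern over a plain map; names and types here are hypothetical simplifications of the NodeDef attr machinery:

```cpp
#include <map>
#include <string>

// Returns false (skip the rewrite) if any required attr is missing;
// otherwise copies both attrs into the new node and returns true.
bool CopyRequiredAttrs(const std::map<std::string, std::string>& src,
                       std::map<std::string, std::string>* dst) {
  for (const char* attr : {"output_shapes", "output_types"}) {
    auto it = src.find(attr);
    if (it == src.end()) return false;  // Missing attr: abort the rewrite.
    (*dst)[attr] = it->second;
  }
  return true;
}
```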
diff --git a/tensorflow/core/grappler/optimizers/data/slack.cc b/tensorflow/core/grappler/optimizers/data/slack.cc
index 27915e2..211b53b 100644
--- a/tensorflow/core/grappler/optimizers/data/slack.cc
+++ b/tensorflow/core/grappler/optimizers/data/slack.cc
@@ -101,10 +101,9 @@
     return Status::OK();
   }
 
-  return errors::InvalidArgument(
-      "Encountered unsupported op \"", dataset_node->op(),
-      "\" when rewriting the input pipeline graph to use slack in its "
-      "final prefetch transformation.");
+  LOG(WARNING) << "Could not find a final `prefetch` in the input pipeline to "
+                  "which to introduce slack.";
+  return Status::OK();
 }
 
 Status Slack::OptimizeAndCollectStats(Cluster* cluster,
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 0c4c046..8f18dfd 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -678,18 +678,41 @@
     find_differentiable_functions(function.node_def());
   }
 
-  // Find functions that are formed by XLA and will be compiled later. We do it
-  // by looking for a function attribute in XlaLaunch ops. Grappler rewrites
-  // potentially can add nodes that are not supported by XLA, so we choose to
-  // skip such functions when we optimize function library.
+  // Find functions that will be compiled by XLA later.
+  // We do this by looking for XlaLaunch ops that call functions, then
+  // depth-first searching those functions to find transitively called
+  // functions. Grappler rewrites can potentially add nodes that are not
+  // supported by XLA, so we choose to skip such functions when we optimize
+  // the function library.
   absl::flat_hash_set<string> xla_compiled_functions;
+  std::function<void(const string&)> find_all_functions;
+  find_all_functions = [&](const string& func) -> void {
+    // Ignore call cycles in the graph.
+    if (xla_compiled_functions.contains(func)) return;
+    // Find `func` in the function library.
+    const FunctionDef* func_def = flib.Find(func);
+    CHECK(func_def) << "Function not found in the library: " << func;
+    // Mark the function to be ignored by Grappler.
+    xla_compiled_functions.insert(func);
+    // Depth-first search through the function for transitively called
+    // functions.
+    for (const NodeDef& node : func_def->node_def()) {
+      for (const auto& attr : node.attr()) {
+        const AttrValue& attr_value = attr.second;
+        if (attr_value.has_func()) {
+          find_all_functions(attr_value.func().name());
+        }
+      }
+    }
+  };
 
-  const auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
+  auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
     NameAttrList function;
     for (const NodeDef& node : nodes) {
+      // Look only for XlaLaunch nodes that call a function
       if (!IsXlaLaunch(node)) continue;
       if (!GetNodeAttr(node, "function", &function).ok()) continue;
-      xla_compiled_functions.insert(function.name());
+      // Find all transitively called functions
+      find_all_functions(function.name());
     }
   };
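[Editor's note] The recursive lambda above is a plain depth-first traversal whose visited set doubles as cycle protection. A self-contained sketch of the same shape, with generic names standing in for the FunctionDef/attr walk the real code performs:

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Marks `root` and everything it transitively calls. The visited set also
// guards against call cycles, exactly as xla_compiled_functions does above.
void MarkTransitive(
    const std::string& root,
    const std::map<std::string, std::vector<std::string>>& calls,
    std::set<std::string>* visited) {
  if (visited->count(root)) return;  // Already seen (or a call cycle).
  visited->insert(root);
  auto it = calls.find(root);
  if (it == calls.end()) return;
  for (const std::string& callee : it->second) {
    MarkTransitive(callee, calls, visited);
  }
}
```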
 
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 46c7afb..ba2cd98 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -361,7 +361,12 @@
 }
 
 bool IsSupportedActivation(const NodeDef& node) {
+// LeakyRelu is temporarily disabled in MKL builds until the MKL PR is
+// merged.
+#ifndef INTEL_MKL
+  return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
+#else
   return IsRelu(node) || IsRelu6(node) || IsElu(node);
+#endif  // !INTEL_MKL
 }
 
 inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
@@ -450,6 +455,14 @@
       IsInPreserveSet(ctx, bias_add_node_def))
     return false;
 
+  // Get the contraction node
+  const auto* contraction_node_view =
+      bias_add_node_view->GetRegularFanin(0).node_view();
+  const auto* contraction_node_def = contraction_node_view->node();
+
+  // Currently, only Conv2D + BiasAdd + LeakyRelu fusion is enabled.
+  if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
+
   // Check that data type and data format are supported on assigned device.
   const ContractionWithBiasAddAndActivation pattern{base.contraction,
                                                     base.bias_add, node_index};
@@ -719,6 +732,16 @@
     return false;
   }
 
+  // Get the contraction node
+  const auto* bias_add_node_view =
+      add_node_view->GetRegularFanin(base.port_id).node_view();
+  const auto* contraction_node_view =
+      bias_add_node_view->GetRegularFanin(0).node_view();
+  const auto* contraction_node_def = contraction_node_view->node();
+
+  // Currently, only Conv2D + BiasAdd + Add + LeakyRelu fusion is enabled.
+  if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
+
   // We successfully found a Conv2D+BiasAdd+AddN+activation pattern.
   const ContractionWithBiasAndAddActivation pattern{
       base.contraction, base.bias_add, base.add, base.port_id, node_index};
@@ -919,7 +942,8 @@
   return false;
 }
 
-void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
+void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
+                          const NodeDef* activation = nullptr) {
   DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
 
   auto* attr = fused_conv2d->mutable_attr();
@@ -932,10 +956,16 @@
   (*attr)["dilations"] = src_attr.at("dilations");
   (*attr)["data_format"] = src_attr.at("data_format");
   (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
+  // Copy LeakyRelu's attr `alpha` to FusedConv2D's attr `leakyrelu_alpha`.
+  if (activation != nullptr && IsLeakyRelu(*activation)) {
+    auto& activation_attr = activation->attr();
+    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
+  }
 }
 
 void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
-                                         NodeDef* fused_dw_conv2d) {
+                                         NodeDef* fused_dw_conv2d,
+                                         const NodeDef* activation = nullptr) {
   DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
       << "Input node must be a DepthwiseConv2dNative";
 
@@ -947,6 +977,11 @@
   (*attr)["padding"] = src_attr.at("padding");
   (*attr)["dilations"] = src_attr.at("dilations");
   (*attr)["data_format"] = src_attr.at("data_format");
+  // Copy LeakyRelu's attr `alpha` to FusedDepthwiseConv2d's attr
+  // `leakyrelu_alpha`.
+  if (activation != nullptr && IsLeakyRelu(*activation)) {
+    auto& activation_attr = activation->attr();
+    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
+  }
 }
 
 void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
@@ -1049,6 +1084,7 @@
   const NodeDef& contraction = graph->node(matched.contraction);
   const NodeDef& bias_add = graph->node(matched.bias_add);
   const NodeDef& activation = graph->node(matched.activation);
+
   VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and "
           << activation.op() << ":"
           << " activation=" << activation.name()
@@ -1064,7 +1100,8 @@
 
   if (IsConv2D(contraction)) {
     fused_op.set_op(kFusedConv2D);
-    CopyConv2DAttributes(contraction, &fused_op);
+    // LeakyRelu has a special attribute `alpha` that must be copied.
+    CopyConv2DAttributes(contraction, &fused_op, &activation);
   } else if (IsDepthwiseConv2dNative(contraction)) {
     fused_op.set_op(kFusedDepthwiseConv2dNative);
     CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
@@ -1202,7 +1239,7 @@
   fused_conv2d.add_input(fused_batch_norm.input(3));  // 4: mean
   fused_conv2d.add_input(fused_batch_norm.input(4));  // 5: variance
 
-  CopyConv2DAttributes(contraction, &fused_conv2d);
+  CopyConv2DAttributes(contraction, &fused_conv2d, &activation);
   SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm", activation.op()},
                        /*num_args=*/4, /*epsilon=*/matched.epsilon);
 
@@ -1284,7 +1321,7 @@
   fused_conv2d.add_input(add.input(1 - matched.port_id));
 
   CopyConv2DAttributes(contraction, &fused_conv2d);
-  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", "Relu"}, 2);
+  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", activation.op()}, 2);
 
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
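[Editor's note] The alpha forwarding in this file is the one LeakyRelu-specific wrinkle in the fusion: unlike Relu, Relu6, or Elu, LeakyRelu carries an `alpha` attribute that the fused kernel needs. A reduced sketch of the copy, with plain maps standing in for the AttrValue maps:

```cpp
#include <map>
#include <string>

// Hypothetical simplification: forward LeakyRelu's `alpha` into the fused
// node under the `leakyrelu_alpha` key, leaving other activations untouched.
void MaybeForwardAlpha(const std::string& activation_op,
                       const std::map<std::string, float>& activation_attr,
                       std::map<std::string, float>* fused_attr) {
  if (activation_op != "LeakyRelu") return;
  (*fused_attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
}
```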
diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc
index f4bc5e3..4cc1332 100644
--- a/tensorflow/core/grappler/optimizers/remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/grappler/optimizers/remapper.h"
 
+#include "tensorflow/cc/ops/nn_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.h"
@@ -541,7 +542,7 @@
 TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
   using ::tensorflow::ops::Placeholder;
 
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
     auto input_shape = Placeholder::Shape({8, 32, 32, 3});
@@ -567,6 +568,13 @@
         return ops::Identity(fetch, ops::Relu6(activate, bias_add));
       } else if (activation == "Elu") {
         return ops::Identity(fetch, ops::Elu(activate, bias_add));
+        // LeakyRelu is temporarily disabled in MKL builds until the MKL PR
+        // is merged.
+#ifndef INTEL_MKL
+      } else if (activation == "LeakyRelu") {
+        auto attr = ops::internal::LeakyRelu::Alpha(0.5);
+        return ops::Identity(
+            fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
+#endif  // !INTEL_MKL
       }
 
       return ops::Identity(fetch, bias);
@@ -605,6 +613,12 @@
         ASSERT_EQ(fused_ops.size(), 2);
         EXPECT_EQ(fused_ops[0], "BiasAdd");
         EXPECT_EQ(fused_ops[1], activation);
+
+#ifndef INTEL_MKL
+        if (activation == "LeakyRelu") {
+          EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
+        }
+#endif  // !INTEL_MKL
         found++;
       }
     }
@@ -795,7 +809,7 @@
 TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
   using ops::Placeholder;
 
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
     auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
@@ -828,6 +842,13 @@
         return ops::Identity(fetch, ops::Relu6(activate, batch_norm.y));
       } else if (activation == "Elu") {
         return ops::Identity(fetch, ops::Elu(activate, batch_norm.y));
+        // LeakyRelu is temporarily disabled in MKL builds until the MKL PR
+        // is merged.
+#ifndef INTEL_MKL
+      } else if (activation == "LeakyRelu") {
+        auto attr = ops::internal::LeakyRelu::Alpha(0.5);
+        return ops::Identity(
+            fetch, ops::internal::LeakyRelu(activate, batch_norm.y, attr));
+#endif  // !INTEL_MKL
       }
 
       return ops::Identity(fetch, batch_norm.y);
@@ -874,6 +895,12 @@
         ASSERT_EQ(fused_ops.size(), 2);
         EXPECT_EQ(fused_ops[0], "FusedBatchNorm");
         EXPECT_EQ(fused_ops[1], activation);
+
+#ifndef INTEL_MKL
+        if (activation == "LeakyRelu") {
+          EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
+        }
+#endif  // !INTEL_MKL
         found++;
       }
     }
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 581109b..d3f2d47 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -243,7 +243,6 @@
     srcs = ["collective_nccl_test.cc"],
     tags = tf_cuda_tests_tags() + [
         "guitar",
-        "manual",
         "multi_gpu",
         "no_oss",
         "notap",
@@ -1662,6 +1661,7 @@
         ":ops_testutil",
         ":ops_util",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -1672,6 +1672,7 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/kernels/image",
+        "//tensorflow/core/platform:tf32_utils",
         "@com_google_absl//absl/algorithm:container",
     ],
 )
@@ -3087,7 +3088,6 @@
         ":logging_ops",
         ":summary_audio_op",
         ":summary_image_op",
-        ":summary_op",
         ":summary_tensor_op",
     ],
 )
@@ -3098,6 +3098,10 @@
     "//tensorflow/core:lib",
     "//tensorflow/core:lib_internal",
     "//tensorflow/core:protos_all_cc",
+    # TODO(b/162630222): remove these dependencies.
+    "//tensorflow/c/kernels:histogram_summary_op",
+    "//tensorflow/c/kernels:merge_summary_op",
+    "//tensorflow/c/kernels:summary_op",
 ]
 
 tf_kernel_library(
@@ -3118,12 +3122,12 @@
     deps = LOGGING_DEPS + ["//tensorflow/core:png_internal"],
 )
 
-tf_kernel_library(
+# TODO(b/162630222): remove this target.
+cc_library(
     name = "summary_op",
-    prefix = "summary_op",
-    deps = LOGGING_DEPS + [
-        # TODO(b/162630222): remove these dependencies.
+    deps = [
         "//tensorflow/c/kernels:histogram_summary_op",
+        "//tensorflow/c/kernels:merge_summary_op",
         "//tensorflow/c/kernels:summary_op",
     ],
 )
@@ -5910,6 +5914,7 @@
         "avgpooling_op.h",
         "batch_matmul_op_impl.h",
         "batch_norm_op.h",
+        "bincount_op.h",
         "broadcast_to_op.h",
         "bucketize_op.h",
         "control_flow_ops.h",
@@ -6127,6 +6132,7 @@
         ":android_extended_ops_headers",
         "base64_ops.cc",
         "batchtospace_op.cc",
+        "bincount_op.cc",
         "broadcast_to_op.cc",
         "bucketize_op.cc",
         "ctc_decoder_ops.cc",
@@ -6217,7 +6223,6 @@
         "string_split_op.cc",
         "string_to_hash_bucket_op.cc",
         "substr_op.cc",
-        "summary_op.cc",
         "tensor_array.cc",
         "tensor_array_ops.cc",
         "tensor_list.cc",
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index b137413..94ba4d8 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -126,7 +126,6 @@
     return GpuLaunchKernel(BincountReduceKernel<Tidx, T>, config.block_count,
                            config.thread_per_block, 0, d.stream(), arr.data(),
                            output.data(), nthreads, num_bins);
-    return Status::OK();
   }
 };
 
@@ -215,14 +214,11 @@
           config.block_count, config.thread_per_block, smem_usage, d.stream(),
           in.data(), weights.data(), weights.size(), out.data(), num_rows,
           num_cols, num_bins);
-    } else {
-      return GpuLaunchKernel(
-          BincountColReduceKernel<Tidx, T, binary_count>, config.block_count,
-          config.thread_per_block, 0, d.stream(), in.data(), weights.data(),
-          weights.size(), out.data(), num_rows, num_cols, num_bins);
     }
-
-    return Status::OK();
+    return GpuLaunchKernel(
+        BincountColReduceKernel<Tidx, T, binary_count>, config.block_count,
+        config.thread_per_block, 0, d.stream(), in.data(), weights.data(),
+        weights.size(), out.data(), num_rows, num_cols, num_bins);
   }
 };
 
diff --git a/tensorflow/core/kernels/collective_nccl_test.cc b/tensorflow/core/kernels/collective_nccl_test.cc
index 00456d9..04a15a0 100644
--- a/tensorflow/core/kernels/collective_nccl_test.cc
+++ b/tensorflow/core/kernels/collective_nccl_test.cc
@@ -316,14 +316,14 @@
       // Run the all-reduce.
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
-      NcclReducer reducer;
+      auto* reducer = new NcclReducer();
       auto col_ctx = std::make_shared<CollectiveContext>(
           parent_->col_exec_, parent_->dev_mgr_.get(),
           /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
           /*input=*/&input_, /*output=*/&input_);
-      TF_CHECK_OK(reducer.InitializeCollectiveContext(col_ctx));
+      TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));
       Notification note;
-      reducer.Run([this, &note](Status s) {
+      reducer->Run([this, &note](Status s) {
         status_ = s;
         note.Notify();
       });
@@ -332,6 +332,7 @@
         CHECK(output_.CopyFrom(*ctx.mutable_output(0), input_.shape()));
       }
 
+      reducer->Unref();
       op_params.op_device_context->Unref();
     }
 
@@ -346,15 +347,15 @@
       // Run broadcast.
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
-      NcclBroadcaster broadcaster;
+      auto* broadcaster = new NcclBroadcaster();
       auto col_ctx = std::make_shared<CollectiveContext>(
           parent_->col_exec_, parent_->dev_mgr_.get(),
           /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
           /*input=*/col_params_.is_source ? &input_ : nullptr,
           /*output=*/&input_);
-      TF_CHECK_OK(broadcaster.InitializeCollectiveContext(col_ctx));
+      TF_CHECK_OK(broadcaster->InitializeCollectiveContext(col_ctx));
       Notification note;
-      broadcaster.Run([this, &note](Status s) {
+      broadcaster->Run([this, &note](Status s) {
         status_ = s;
         note.Notify();
       });
@@ -363,6 +364,7 @@
         CHECK(output_.CopyFrom(input_, input_.shape()));
       }
 
+      broadcaster->Unref();
       op_params.op_device_context->Unref();
     }
 
@@ -385,20 +387,21 @@
       // Run gather.
       string exec_key =
           strings::StrCat(col_params_.instance.instance_key, ":0:0");
-      NcclGatherer gatherer;
+      auto* gatherer = new NcclGatherer();
       auto col_ctx = std::make_shared<CollectiveContext>(
           parent_->col_exec_, parent_->dev_mgr_.get(),
           /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
           /*input=*/&input_,
           /*output=*/&output_);
-      TF_CHECK_OK(gatherer.InitializeCollectiveContext(col_ctx));
+      TF_CHECK_OK(gatherer->InitializeCollectiveContext(col_ctx));
       Notification note;
-      gatherer.Run([this, &note](Status s) {
+      gatherer->Run([this, &note](Status s) {
         status_ = s;
         note.Notify();
       });
       note.WaitForNotification();
 
+      gatherer->Unref();
       op_params.op_device_context->Unref();
     }
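[Editor's note] The test changes above move the collectives from the stack to the heap and release them with `Unref()`, matching TensorFlow's intrusive reference counting (`core::RefCounted`), where an object is born with one reference and deletes itself when the count reaches zero. A generic sketch of that discipline, using a hypothetical `RefCounted` base rather than TensorFlow's exact class:

```cpp
#include <atomic>

class RefCounted {
 public:
  RefCounted() : refs_(1) {}  // Born holding one reference.
  void Ref() { refs_.fetch_add(1); }
  void Unref() {
    // fetch_sub returns the previous count; 1 means we were the last owner.
    if (refs_.fetch_sub(1) == 1) delete this;
  }

 protected:
  virtual ~RefCounted() = default;  // Heap-only: destroyed via Unref().

 private:
  std::atomic<int> refs_;
};
```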
 
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 322da25..8f81113 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -1394,7 +1394,9 @@
         "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
 
     const int device_id = stream->parent()->device_ordinal();
-    DataType dtype = context->input(0).dtype();
+    // To make sure Conv3DBackpropInputV2 gets the correct dtype, we infer the
+    // dtype from the 2nd input, i.e., out_backprop.
+    DataType dtype = context->input(2).dtype();
     const ConvParameters conv_parameters = {
         dims.batch_size,
         dims.in_depth,
diff --git a/tensorflow/core/kernels/conv_ops_fused_impl.h b/tensorflow/core/kernels/conv_ops_fused_impl.h
index f838d05..43aa6c6 100644
--- a/tensorflow/core/kernels/conv_ops_fused_impl.h
+++ b/tensorflow/core/kernels/conv_ops_fused_impl.h
@@ -106,6 +106,21 @@
   template <typename OutputKernel>
   void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx,
                   const Tensor& input, const Tensor& filter, Tensor* output) {
+    // Wrap output_kernel into a type-erased function to reduce the number of
+    // unique template instantiations of Eigen Tensor contraction expressions.
+    using OutputKernelFn =
+        std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
+                           const Eigen::TensorContractionParams&, Eigen::Index,
+                           Eigen::Index, Eigen::Index, Eigen::Index)>;
+
+    OutputKernelFn output_kernel_fn =
+        [&output_kernel](
+            const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
+            const Eigen::TensorContractionParams& params, Eigen::Index i,
+            Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
+          output_kernel(output_mapper, params, i, j, num_rows, num_cols);
+        };
+
     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 &&
         row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) {
       int conv_width = 1;  // Width for the convolution step.
@@ -115,12 +130,12 @@
 
       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
-      functor::MatMulConvFunctor<CPUDevice, T, OutputKernel>()(
+      functor::MatMulConvFunctor<CPUDevice, T, OutputKernelFn>()(
           ctx->eigen_device<CPUDevice>(),
           output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
           input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
-          dim_pair, output_kernel);
+          dim_pair, std::move(output_kernel_fn));
 
     } else if (filter.dim_size(0) == input.dim_size(1) &&
                filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 &&
@@ -132,29 +147,30 @@
 
       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
-      functor::MatMulConvFunctor<CPUDevice, T, OutputKernel>()(
+      functor::MatMulConvFunctor<CPUDevice, T, OutputKernelFn>()(
           ctx->eigen_device<CPUDevice>(),
           output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
           input.shaped<T, 2>({input.dim_size(0), k}),
           filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair,
-          output_kernel);
+          std::move(output_kernel_fn));
 
     } else {
       if (padding_ == EXPLICIT) {
-        functor::SpatialConvolution<CPUDevice, T, OutputKernel>()(
+        functor::SpatialConvolution<CPUDevice, T, OutputKernelFn>()(
             ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
             col_stride_, row_dilation_, col_dilation_,
             static_cast<int>(explicit_paddings_[2]),
             static_cast<int>(explicit_paddings_[3]),
             static_cast<int>(explicit_paddings_[4]),
-            static_cast<int>(explicit_paddings_[5]), output_kernel);
+            static_cast<int>(explicit_paddings_[5]),
+            std::move(output_kernel_fn));
       } else {
-        functor::SpatialConvolution<CPUDevice, T, OutputKernel>()(
+        functor::SpatialConvolution<CPUDevice, T, OutputKernelFn>()(
             ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
             col_stride_, row_dilation_, col_dilation_,
-            BrainPadding2EigenPadding(padding_), output_kernel);
+            BrainPadding2EigenPadding(padding_), std::move(output_kernel_fn));
       }
     }
   }
@@ -185,14 +201,26 @@
 
     BiasAddArgs<T> bias_add_args;
     if (BiasAddArgs<T>::IsSupported(fusion)) {
-      OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
+      if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
+        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
+                                                &fusion_args.leakyrelu_alpha));
+      } else {
+        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
+      }
     }
 
     FusedBatchNormArgs<T> fused_batch_norm_args;
     if (FusedBatchNormArgs<T>::IsSupported(fusion)) {
-      OP_REQUIRES_OK(context,
-                     InitFusedBatchNormArgs(context, fusion_args.epsilon,
-                                            &fused_batch_norm_args));
+      if (fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu) {
+        OP_REQUIRES_OK(context,
+                       InitFusedBatchNormArgs(context, fusion_args.epsilon,
+                                              &fused_batch_norm_args,
+                                              &fusion_args.leakyrelu_alpha));
+      } else {
+        OP_REQUIRES_OK(context,
+                       InitFusedBatchNormArgs(context, fusion_args.epsilon,
+                                              &fused_batch_norm_args));
+      }
     }
 
     LaunchFusedConv2DWithOutputKernel<T> conv2d(
@@ -215,6 +243,10 @@
         conv2d(WithBiasAddAndRelu6<T>(bias_add_args), context, input, filter,
                output);
         break;
+      case FusedComputationType::kBiasAddWithLeakyRelu:
+        conv2d(WithBiasAddAndLeakyRelu<T>(bias_add_args), context, input,
+               filter, output);
+        break;
       case FusedComputationType::kBiasAddWithElu:
         conv2d(WithBiasAddAndElu<T>(bias_add_args), context, input, filter,
                output);
@@ -234,6 +266,11 @@
                                              fused_batch_norm_args),
                context, input, filter, output);
         break;
+      case FusedComputationType::kFusedBatchNormWithLeakyRelu:
+        conv2d(WithFusedBatchNormAndLeakyRelu<T>(fusion_args.epsilon,
+                                                 fused_batch_norm_args),
+               context, input, filter, output);
+        break;
       case FusedComputationType::kFusedBatchNormWithElu:
         conv2d(WithFusedBatchNormAndElu<T>(fusion_args.epsilon,
                                            fused_batch_norm_args),
@@ -681,10 +718,12 @@
           {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
           {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
           {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
+          {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
           {FCT::kFusedBatchNorm, {"FusedBatchNorm"}},
           {FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}},
           {FCT::kFusedBatchNormWithRelu6, {"FusedBatchNorm", "Relu6"}},
           {FCT::kFusedBatchNormWithElu, {"FusedBatchNorm", "Elu"}},
+          {FCT::kFusedBatchNormWithLeakyRelu, {"FusedBatchNorm", "LeakyRelu"}},
       };
     }
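[Editor's note] The `OutputKernelFn` wrapper introduced in this file trades a small per-call indirection for far fewer template instantiations: previously every distinct output-kernel type stamped out its own contraction expression. A minimal sketch of the trade-off, as a generic example rather than the Eigen signature:

```cpp
#include <functional>
#include <iostream>

// Without erasure, each distinct Kernel type instantiates its own copy of
// RunContractionTemplated<Kernel>, bloating compile time and binary size.
template <typename Kernel>
void RunContractionTemplated(Kernel k) { k(42); }

// With erasure, every caller shares a single instantiation; only the small
// std::function wrapper differs per kernel.
void RunContractionErased(const std::function<void(int)>& k) { k(42); }

int main() {
  RunContractionErased([](int i) { std::cout << "bias+relu on " << i << "\n"; });
  RunContractionErased([](int i) { std::cout << "bias+elu on " << i << "\n"; });
  return 0;
}
```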
 
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index 3e192b8..e8e156b 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/image_ops.h"
 #include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
 #include "tensorflow/core/framework/fake_input.h"
@@ -32,6 +33,7 @@
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tf32_utils.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 #include "tensorflow/core/public/session.h"
 
@@ -652,6 +654,8 @@
       ops::Relu6(root.WithOpName("with_activation"), with_bias);
     } else if (activation_type == "Elu") {
       ops::Elu(root.WithOpName("with_activation"), with_bias);
+    } else if (activation_type == "LeakyRelu") {
+      ops::internal::LeakyRelu(root.WithOpName("with_activation"), with_bias);
     } else {
       ops::Identity(root.WithOpName("with_activation"), with_bias);
     }
@@ -721,6 +725,9 @@
       ops::Relu6(root.WithOpName("with_activation"), with_fused_batch_norm.y);
     } else if (activation_type == "Elu") {
       ops::Elu(root.WithOpName("with_activation"), with_fused_batch_norm.y);
+    } else if (activation_type == "LeakyRelu") {
+      ops::internal::LeakyRelu(root.WithOpName("with_activation"),
+                               with_fused_batch_norm.y);
     } else {
       ops::Identity(root.WithOpName("with_activation"),
                     with_fused_batch_norm.y);
@@ -1038,9 +1045,10 @@
 #endif
 
 TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) {
+  tensorflow::allow_tf32_execution(false);  // Requires full precision Conv2D op
   const int filter_size = 1;
   const int filter_count = 12;
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBiasAndActivation(activation, filter_size,
                                             filter_count);
   }
@@ -1049,7 +1057,7 @@
 TYPED_TEST_P(FusedConv2DWithBiasOpTest, ImageSizeConvolutionAndActivation) {
   const int filter_size = TestFixture::kImageWidth;
   const int filter_count = 12;
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBiasAndActivation(activation, filter_size,
                                             filter_count);
   }
@@ -1058,7 +1066,7 @@
 TYPED_TEST_P(FusedConv2DWithBiasOpTest, SpatialConvolutionAndActivation) {
   const int filter_size = 3;
   const int filter_count = 12;
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBiasAndActivation(activation, filter_size,
                                             filter_count);
   }
@@ -1069,7 +1077,7 @@
              ExplicitPaddingConvolutionAndActivation) {
   const int filter_size = 3;
   const int filter_count = 12;
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBiasAndActivation(
         activation, filter_size, filter_count,
         /*explicit_paddings=*/{0, 0, 1, 2, 3, 4, 0, 0});
@@ -1112,7 +1120,7 @@
 TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, OneByOneConvolutionAndActivation) {
   const int filter_size = 1;
   const int filter_count = 12;
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size,
                                                  filter_count);
   }
@@ -1122,7 +1130,7 @@
              ImageSizeConvolutionAndActivation) {
   const int filter_size = TestFixture::kImageWidth;
   const int filter_count = 12;
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size,
                                                  filter_count);
   }
@@ -1131,7 +1139,7 @@
 TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, SpatialConvolutionAndActivation) {
   const int filter_size = 3;
   const int filter_count = 12;
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size,
                                                  filter_count);
   }
@@ -1142,7 +1150,7 @@
              ExplicitPaddingConvolutionAndActivation) {
   const int filter_size = 3;
   const int filter_count = 12;
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBatchNormAndActivation(
         activation, filter_size, filter_count,
         /*explicit_paddings=*/{0, 0, 1, 2, 3, 4, 0, 0});
diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc
index 9262845..175cba3 100644
--- a/tensorflow/core/kernels/cwise_op_sigmoid.cc
+++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc
@@ -17,8 +17,8 @@
 #include "tensorflow/core/kernels/cwise_ops_gradients.h"
 
 namespace tensorflow {
-REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
-          complex64, complex128);
+REGISTER6(UnaryOp, CPU, "Sigmoid", functor::sigmoid, bfloat16, float,
+          Eigen::half, double, complex64, complex128);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half,
           double);
@@ -27,8 +27,8 @@
 REGISTER(UnaryOp, SYCL, "Sigmoid", functor::sigmoid, float);
 #endif  // TENSORFLOW_USE_SYCL
 
-REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float,
-          Eigen::half, double, complex64, complex128);
+REGISTER6(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, bfloat16,
+          float, Eigen::half, double, complex64, complex128);
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float,
           Eigen::half, double);
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 0066764..681bc1f 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -519,13 +519,6 @@
       return Status::OK();
     }
   }
-  for (const auto& node : fdef->node_def()) {
-    if (node.op() == kDataServiceDataset) {
-      return errors::InvalidArgument(
-          "The `.distribute(...)` dataset transformation is not supported "
-          "within tf.data functions.");
-    }
-  }
   return Status::OK();
 }
 
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index 68b3ea5..46e724c 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -264,11 +264,6 @@
 };
 
 }  // namespace data
-
-// TODO(b/114112161): Remove these aliases when all users have moved over to the
-// `tensorflow::data` namespace.
-using data::CapturedFunction;
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index 1c35415..d893925 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -225,7 +225,7 @@
             [&]() {
               return dispatcher_->CreateJob(dataset()->dataset_id_,
                                             dataset()->processing_mode_,
-                                            &job_client_id_);
+                                            job_client_id_);
             },
             "create job", deadline_micros));
       } else {
@@ -233,7 +233,7 @@
             [&]() {
               return dispatcher_->GetOrCreateJob(
                   dataset()->dataset_id_, dataset()->processing_mode_,
-                  dataset()->job_name_, iterator_index_, &job_client_id_);
+                  dataset()->job_name_, iterator_index_, job_client_id_);
             },
             "get or create job", deadline_micros));
       }
@@ -347,7 +347,7 @@
       VLOG(3) << "Updating tasks";
       std::vector<TaskInfo> tasks;
       bool job_finished;
-      Status s = dispatcher_->GetTasks(job_client_id_, &tasks, &job_finished);
+      Status s = dispatcher_->GetTasks(job_client_id_, tasks, job_finished);
       if (!s.ok()) {
         LOG(WARNING) << "Failed to get task info for job client id "
                      << job_client_id_ << ": " << s;
@@ -382,7 +382,7 @@
         TaskInfo& task_info = new_task_entry.second;
         std::unique_ptr<DataServiceWorkerClient> worker;
         Status s = CreateDataServiceWorkerClient(task_info.worker_address(),
-                                                 dataset()->protocol_, &worker);
+                                                 dataset()->protocol_, worker);
         if (!s.ok()) {
           status_ = s;
           get_next_cv_.notify_all();
@@ -489,8 +489,8 @@
       CompressedElement compressed;
       bool end_of_sequence;
       for (int num_retries = 0;; ++num_retries) {
-        Status s = task->worker->GetElement(task->task_id, &compressed,
-                                            &end_of_sequence);
+        Status s = task->worker->GetElement(task->task_id, compressed,
+                                            end_of_sequence);
         if (s.ok()) {
           break;
         }
@@ -629,7 +629,7 @@
       ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str));
   ProcessingMode processing_mode;
   OP_REQUIRES_OK(ctx,
-                 ParseProcessingMode(processing_mode_str, &processing_mode));
+                 ParseProcessingMode(processing_mode_str, processing_mode));
 
   tstring address;
   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
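The call-site changes in this hunk track a signature change in the dispatcher/worker clients: output parameters move from pointers to references, so callers pass `job_client_id_` rather than `&job_client_id_`. A sketch of the style with a toy Status type and a simplified signature, not the actual client API:

#include <cstdint>
#include <cstdio>

struct Status {  // toy stand-in for tensorflow::Status
  bool ok = true;
};

// Out-parameter by reference: cannot be null, and the call site reads as a
// plain variable rather than an address-of expression.
Status CreateJob(int64_t dataset_id, int64_t& job_client_id) {
  job_client_id = 42;  // illustrative value
  return Status{};
}

int main() {
  int64_t id = 0;
  Status s = CreateJob(7, id);  // note: no '&' at the call site
  std::printf("ok=%d id=%lld\n", s.ok ? 1 : 0, static_cast<long long>(id));
  return 0;
}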
diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.cc b/tensorflow/core/kernels/data/experimental/data_service_ops.cc
index ba17581..91bc742 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_ops.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_ops.cc
@@ -63,7 +63,7 @@
   int64 deadline_micros = EnvTime::NowMicros() + kRetryTimeoutMicros;
   OP_REQUIRES_OK(
       ctx, grpc_util::Retry(
-               [&]() { return client.RegisterDataset(graph_def, &dataset_id); },
+               [&]() { return client.RegisterDataset(graph_def, dataset_id); },
                /*description=*/"register dataset", deadline_micros));
 
   Tensor* output;
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 24d4934..f7325cd 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -84,7 +84,7 @@
     // of the jobs, the experiments will be randomly turned on.
     // clang-format off
     absl::flat_hash_map<string, uint64> live_experiments = {
-        {"disable_intra_op_parallelism", 1}
+        {"disable_intra_op_parallelism", 5}
     };
     // clang-format on
     auto hash_func = [](const string& str) { return Hash64(str); };
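A hedged sketch of how a hash-based rollout like the `live_experiments` table can behave; `InExperiment` here is illustrative, not the actual tf.data helper. With the value raised from 1 to 5, roughly 5% of job names should hash into the experiment.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <string>

// Illustrative only: a job joins an experiment when its hash lands below the
// rollout percentage.
static bool InExperiment(const std::string& job_name,
                         const std::string& experiment, uint64_t pct) {
  uint64_t h = std::hash<std::string>{}(job_name + "_" + experiment);
  return h % 100 < pct;
}

int main() {
  int on = 0;
  for (int i = 0; i < 1000; ++i) {
    on += InExperiment("job_" + std::to_string(i),
                       "disable_intra_op_parallelism", 5);
  }
  std::printf("opted in: %d of 1000 (expect roughly 50)\n", on);
  return 0;
}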
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 7b64d9e..2bcc2b6 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -48,6 +48,14 @@
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 #include "tensorflow/core/util/transform_output_iterator.h"
 
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+using stream_executor::cuda::ScopedActivateExecutorContext;
+#elif TENSORFLOW_USE_ROCM
+#include "tensorflow/core/platform/rocm.h"
+using stream_executor::rocm::ScopedActivateExecutorContext;
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
@@ -302,6 +310,9 @@
     TensorReference partition_ref(partition_count);
     auto wrapped_callback = [this, c, &data, &partitions, indices_out,
                              partition_ref, cpu_tensor, done]() {
+      auto stream = c->op_device_context()->stream();
+      ScopedActivateExecutorContext scoped_activation{stream->parent()};
+
       OpOutputList outputs;
       this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);
       if (!c->status().ok()) {
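The new ScopedActivateExecutorContext guard matters because this callback runs on a host thread where the GPU context is not necessarily current; the RAII object makes it current for the duration of the allocations the callback performs. A minimal sketch of the pattern, with the driver calls stubbed out so it compiles anywhere:

#include <cstdio>

static int g_current_device = -1;  // stand-in for the thread's active context

class ScopedActivateDevice {
 public:
  explicit ScopedActivateDevice(int device) : prev_(g_current_device) {
    g_current_device = device;  // real code activates via StreamExecutor
  }
  ~ScopedActivateDevice() { g_current_device = prev_; }  // restore on exit

 private:
  int prev_;
};

int main() {
  {
    ScopedActivateDevice activate(/*device=*/0);
    std::printf("inside callback: device %d is current\n", g_current_device);
  }
  std::printf("after callback: device %d\n", g_current_device);
  return 0;
}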
diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.cc b/tensorflow/core/kernels/fused_eigen_output_kernels.cc
index 94e621a..e8e9fd6 100644
--- a/tensorflow/core/kernels/fused_eigen_output_kernels.cc
+++ b/tensorflow/core/kernels/fused_eigen_output_kernels.cc
@@ -60,18 +60,25 @@
   if (*fused_computation == FusedComputationType::kBiasAdd ||
       *fused_computation == FusedComputationType::kBiasAddWithRelu ||
       *fused_computation == FusedComputationType::kBiasAddWithRelu6 ||
-      *fused_computation == FusedComputationType::kBiasAddWithElu) {
+      *fused_computation == FusedComputationType::kBiasAddWithElu ||
+      *fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) {
     if (num_args != 1) {
       return errors::InvalidArgument(
           "Fused ", kernel_name,
           " with BiasAdd must have one extra argument: bias.");
     }
+    if (*fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) {
+      TF_RETURN_IF_ERROR(context->GetAttr(
+          "leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha));
+    }
   }
 
   if (*fused_computation == FusedComputationType::kFusedBatchNorm ||
       *fused_computation == FusedComputationType::kFusedBatchNormWithRelu ||
       *fused_computation == FusedComputationType::kFusedBatchNormWithRelu6 ||
-      *fused_computation == FusedComputationType::kFusedBatchNormWithElu) {
+      *fused_computation == FusedComputationType::kFusedBatchNormWithElu ||
+      *fused_computation ==
+          FusedComputationType::kFusedBatchNormWithLeakyRelu) {
     if (num_args != 4) {
       return errors::InvalidArgument(
           "Fused ", kernel_name,
@@ -80,6 +87,11 @@
     }
     TF_RETURN_IF_ERROR(
         context->GetAttr("epsilon", &fused_computation_args->epsilon));
+    if (*fused_computation ==
+        FusedComputationType::kFusedBatchNormWithLeakyRelu) {
+      TF_RETURN_IF_ERROR(context->GetAttr(
+          "leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha));
+    }
   }
 
   return Status::OK();
diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.h b/tensorflow/core/kernels/fused_eigen_output_kernels.h
index 2588da1..546cf39 100644
--- a/tensorflow/core/kernels/fused_eigen_output_kernels.h
+++ b/tensorflow/core/kernels/fused_eigen_output_kernels.h
@@ -39,15 +39,18 @@
   kBiasAddWithRelu,
   kBiasAddWithRelu6,
   kBiasAddWithElu,
+  kBiasAddWithLeakyRelu,
   kFusedBatchNorm,
   kFusedBatchNormWithRelu,
   kFusedBatchNormWithRelu6,
-  kFusedBatchNormWithElu
+  kFusedBatchNormWithElu,
+  kFusedBatchNormWithLeakyRelu
 };
 
 // We have to pass around additional arguments for all possible fusion types.
 struct FusedComputationArgs {
-  float epsilon = 0.0;  // Used by `FusedBatchNorm` fusion only
+  float epsilon = 0.0;          // Used by `FusedBatchNorm` fusion only
+  float leakyrelu_alpha = 0.0;  // Used by `LeakyRelu` fusion only
 };
 
 struct FusedComputationPattern {
@@ -111,15 +114,32 @@
   };
 };
 
+// Applies `LeakyRelu` to the passed input expression.
+struct LeakyRelu {
+  template <typename XprType>
+  static auto apply(XprType expr, const float leakyrelu_alpha) -> decltype(
+      (expr < std::declval<typename XprType::Scalar>())
+          .select(expr *
+                      expr.constant(std::declval<typename XprType::Scalar>()),
+                  expr)) {
+    return (expr < static_cast<typename XprType::Scalar>(0))
+        .select(expr * expr.constant(static_cast<typename XprType::Scalar>(
+                           leakyrelu_alpha)),
+                expr);
+  };
+};
+
 template <typename T>
 struct BiasAddArgs {
   const T* bias_add_data = nullptr;
+  float leakyrelu_alpha = 0.0;
 
   static bool IsSupported(FusedComputationType fusion) {
     return fusion == FusedComputationType::kBiasAdd ||
            fusion == FusedComputationType::kBiasAddWithRelu ||
            fusion == FusedComputationType::kBiasAddWithRelu6 ||
-           fusion == FusedComputationType::kBiasAddWithElu;
+           fusion == FusedComputationType::kBiasAddWithElu ||
+           fusion == FusedComputationType::kBiasAddWithLeakyRelu;
   }
 };
 
@@ -134,11 +154,14 @@
   //   scaling_factor = (estimated_variance + epsilon).rsqrt() * scale
   Eigen::Tensor<T, 1, Eigen::RowMajor> scaling_factor;
 
+  float leakyrelu_alpha = 0.0;
+
   static bool IsSupported(FusedComputationType fusion) {
     return fusion == FusedComputationType::kFusedBatchNorm ||
            fusion == FusedComputationType::kFusedBatchNormWithRelu ||
            fusion == FusedComputationType::kFusedBatchNormWithRelu6 ||
-           fusion == FusedComputationType::kFusedBatchNormWithElu;
+           fusion == FusedComputationType::kFusedBatchNormWithElu ||
+           fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu;
   }
 };
 
@@ -203,6 +226,34 @@
   const T* bias_data;
 };
 
+template <typename T>
+struct BiasAddOutputKernel<T, LeakyRelu> {
+  explicit BiasAddOutputKernel(const BiasAddArgs<T>& args)
+      : bias_data(args.bias_add_data), leakyrelu_alpha(args.leakyrelu_alpha) {}
+
+  template <typename StorageIndex, typename Scalar>
+  EIGEN_ALWAYS_INLINE void operator()(
+      const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
+      const Eigen::TensorContractionParams& params, StorageIndex i,
+      StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
+    DCHECK(params.swapped_arguments);
+
+    const T* bias_base = bias_data + i;
+    typename TTypes<T>::UnalignedConstTensor bias(bias_base, num_rows);
+
+    for (int col = 0; col < num_cols; ++col) {
+      T* output_base = &output_mapper(0, col);
+      typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
+      const auto expr = output + bias;
+      output = LeakyRelu::template apply<decltype(expr)>(expr, leakyrelu_alpha);
+    }
+  }
+
+ private:
+  const T* bias_data;
+  float leakyrelu_alpha;
+};
+
 // Output kernel that fuses FusedBatchNorm operation into the output of tensor
 // contraction + activation function defined by Activation.
 template <typename T, typename Activation = Identity>
@@ -247,6 +298,51 @@
   const T* estimated_mean_data;
 };
 
+template <typename T>
+struct FusedBatchNormOutputKernel<T, LeakyRelu> {
+  FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs<T>& args)
+      : epsilon(epsilon),
+        scaling_factor_data(args.scaling_factor.data()),
+        offset_data(args.offset_data),
+        estimated_mean_data(args.estimated_mean_data),
+        leakyrelu_alpha(args.leakyrelu_alpha) {}
+
+  template <typename StorageIndex, typename Scalar>
+  EIGEN_ALWAYS_INLINE void operator()(
+      const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
+      const Eigen::TensorContractionParams& params, StorageIndex i,
+      StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
+    DCHECK(params.swapped_arguments);
+
+    const T* scaling_factor_base = scaling_factor_data + i;
+    const T* offset_base = offset_data + i;
+    const T* mean_base = estimated_mean_data + i;
+
+    typename TTypes<T>::UnalignedConstTensor scaling_factor(scaling_factor_base,
+                                                            num_rows);
+    typename TTypes<T>::UnalignedConstTensor offset(offset_base, num_rows);
+    typename TTypes<T>::UnalignedConstTensor mean(mean_base, num_rows);
+
+    for (int col = 0; col < num_cols; ++col) {
+      T* output_base = &output_mapper(0, col);
+      typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
+
+      auto scaled = (output - mean) * scaling_factor;
+      auto shifted = scaled + offset;
+
+      output = LeakyRelu::template apply<decltype(shifted)>(shifted,
+                                                            leakyrelu_alpha);
+    }
+  }
+
+ private:
+  T epsilon;
+  const T* scaling_factor_data;
+  const T* offset_data;
+  const T* estimated_mean_data;
+  float leakyrelu_alpha;
+};
+
 // Type aliases for the output kernels, purely for the sake of better launch
 // dispatching code readability.
 template <typename T>
@@ -258,6 +354,8 @@
 template <typename T>
 using WithBiasAddAndElu = BiasAddOutputKernel<T, Elu>;
 template <typename T>
+using WithBiasAddAndLeakyRelu = BiasAddOutputKernel<T, LeakyRelu>;
+template <typename T>
 using WithFusedBatchNorm = FusedBatchNormOutputKernel<T>;
 template <typename T>
 using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel<T, Relu>;
@@ -265,9 +363,12 @@
 using WithFusedBatchNormAndRelu6 = FusedBatchNormOutputKernel<T, Relu6>;
 template <typename T>
 using WithFusedBatchNormAndElu = FusedBatchNormOutputKernel<T, Elu>;
+template <typename T>
+using WithFusedBatchNormAndLeakyRelu = FusedBatchNormOutputKernel<T, LeakyRelu>;
 
 template <typename T>
-Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs<T>* args) {
+Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs<T>* args,
+                       const float* leakyrelu_alpha = nullptr) {
   // Bias of the following dimensions: [ output_depth ]
   const Tensor& bias = context->input(2);
 
@@ -281,12 +382,17 @@
 
   args->bias_add_data = data_ptr(bias);
 
+  if (leakyrelu_alpha) {
+    args->leakyrelu_alpha = *leakyrelu_alpha;
+  }
+
   return Status::OK();
 }
 
 template <typename T>
 Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon,
-                              FusedBatchNormArgs<T>* args) {
+                              FusedBatchNormArgs<T>* args,
+                              const float* leakyrelu_alpha = nullptr) {
   const Tensor& scale = context->input(2);
   const Tensor& offset = context->input(3);
   const Tensor& estimated_mean = context->input(4);
@@ -319,6 +425,10 @@
       (estimated_variance.flat<T>() + static_cast<T>(epsilon)).rsqrt() *
       scale.flat<T>();
 
+  if (leakyrelu_alpha) {
+    args->leakyrelu_alpha = *leakyrelu_alpha;
+  }
+
   return Status::OK();
 }
 
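A scalar reference sketch of what the new BiasAddOutputKernel<T, LeakyRelu> specialization computes per output column, with plain arrays standing in for the Eigen expressions: add the bias, then apply x < 0 ? alpha * x : x elementwise.

#include <cstdio>

static void BiasAddLeakyReluColumn(float* col, const float* bias, int rows,
                                   float alpha) {
  for (int r = 0; r < rows; ++r) {
    const float v = col[r] + bias[r];   // output + bias
    col[r] = v < 0.0f ? alpha * v : v;  // LeakyRelu::apply analogue
  }
}

int main() {
  float col[3] = {-2.0f, 0.5f, 1.0f};
  const float bias[3] = {1.0f, -1.0f, 0.0f};
  BiasAddLeakyReluColumn(col, bias, 3, /*alpha=*/0.2f);
  std::printf("%g %g %g\n", col[0], col[1], col[2]);  // -0.2 -0.1 1
  return 0;
}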
diff --git a/tensorflow/core/kernels/image/BUILD b/tensorflow/core/kernels/image/BUILD
index f7ad9ab..0d69a38 100644
--- a/tensorflow/core/kernels/image/BUILD
+++ b/tensorflow/core/kernels/image/BUILD
@@ -276,7 +276,7 @@
 tf_kernel_library(
     name = "resize_bilinear_op",
     prefix = "resize_bilinear_op",
-    deps = IMAGE_DEPS,
+    deps = IMAGE_DEPS + ["//tensorflow/core/kernels:cast_op"],
 )
 
 tf_kernel_library(
diff --git a/tensorflow/core/kernels/image/resize_bilinear_op.cc b/tensorflow/core/kernels/image/resize_bilinear_op.cc
index b9eb650..b84c7aa 100644
--- a/tensorflow/core/kernels/image/resize_bilinear_op.cc
+++ b/tensorflow/core/kernels/image/resize_bilinear_op.cc
@@ -16,6 +16,10 @@
 // See docs in ../ops/image_ops.cc
 #define EIGEN_USE_THREADS
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
 #include "tensorflow/core/kernels/image/resize_bilinear_op.h"
 
 #ifdef __SSE4_1__
@@ -30,6 +34,7 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/cast_op.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/image_resizer_state.h"
@@ -281,6 +286,25 @@
   }
 }
 
+template <typename Device>
+struct CastFloatToHalf {
+  void operator()(const Device& d, typename TTypes<float>::ConstFlat input,
+                  typename TTypes<Eigen::half>::Flat output) {
+    output.device(d) = input.template cast<Eigen::half>();
+  }
+};
+
+template <>
+struct CastFloatToHalf<GPUDevice> {
+  void operator()(const GPUDevice& d, typename TTypes<float>::ConstFlat input,
+                  typename TTypes<Eigen::half>::Flat output) {
+    // Use the existing cast functor instead of casting the Eigen tensor
+    // directly, as otherwise we would need to instantiate the cast function
+    // in a .cu.cc file.
+    functor::CastFunctor<GPUDevice, Eigen::half, float> cast;
+    cast(d, output, input);
+  }
+};
+
 }  // namespace
 
 // Partial specialization of ResizeBilinear functor for a CPUDevice.
@@ -355,11 +379,29 @@
     if (!context->status().ok()) return;
 
     TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();
-    typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());
 
-    functor::ResizeBilinearGrad<Device, T>()(
-        context->eigen_device<Device>(), input_grad, st.height_scale,
-        st.width_scale, half_pixel_centers_, output_grad);
+    if (!std::is_same<T, Eigen::half>::value) {
+      typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());
+      functor::ResizeBilinearGrad<Device, T>()(
+          context->eigen_device<Device>(), input_grad, st.height_scale,
+          st.width_scale, half_pixel_centers_, output_grad);
+    } else {
+      // Accumulate the output into a float tensor instead of a half tensor,
+      // since float accumulation is more numerically stable and the GPU half
+      // implementation is slow.
+      // TODO(b/165759037): Create an optimized and numerically stable half
+      // implementation.
+      Tensor output_grad;
+      OP_REQUIRES_OK(context, context->allocate_temp(
+                                  DT_FLOAT, st.output->shape(), &output_grad));
+      functor::ResizeBilinearGrad<Device, float>()(
+          context->eigen_device<Device>(), input_grad, st.height_scale,
+          st.width_scale, half_pixel_centers_, output_grad.tensor<float, 4>());
+      const Tensor& output_grad_const = output_grad;
+      CastFloatToHalf<Device>{}(context->template eigen_device<Device>(),
+                                output_grad_const.template flat<float>(),
+                                st.output->template flat<Eigen::half>());
+    }
   }
 
  private:
@@ -479,7 +521,7 @@
                               .HostMemory("size"),    \
                           ResizeBilinearOp<GPUDevice, T>);
 
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_KERNEL);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);
 
 #undef REGISTER_KERNEL
 
@@ -488,7 +530,7 @@
       Name("ResizeBilinearGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
       ResizeBilinearOpGrad<GPUDevice, T>);
 
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GRAD_KERNEL);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GRAD_KERNEL);
 
 #undef REGISTER_GRAD_KERNEL
 
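The float-accumulate-then-cast path above exists because low-precision accumulators stop absorbing small addends once the running sum is large; for half this kicks in at sums as small as 2048. The same effect, shown with float versus double so the sketch runs without a half type:

#include <cstdio>

int main() {
  float f = 0.0f;
  double d = 0.0;
  for (int i = 0; i < 20000000; ++i) {
    f += 1.0f;  // stalls at 2^24 = 16777216, where 1.0 is below one ulp
    d += 1.0;
  }
  std::printf("float accumulator:  %.0f\n", f);  // 16777216
  std::printf("double accumulator: %.0f\n", d);  // 20000000
  return 0;
}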
diff --git a/tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc b/tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc
index aa475a4..c8dfe75 100644
--- a/tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc
@@ -442,13 +442,17 @@
   }
 };
 
-#define DEFINE_GPU_SPECS(T)                     \
-  template struct ResizeBilinear<GPUDevice, T>; \
+#define DEFINE_GPU_SPEC(T) template struct ResizeBilinear<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
+
+#define DEFINE_GRAD_GPU_SPEC(T) \
   template struct ResizeBilinearGrad<GPUDevice, T>;
 
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GRAD_GPU_SPEC);
 
-#undef DEFINE_GPU_SPECS
+#undef DEFINE_GPU_SPEC
+#undef DEFINE_GRAD_GPU_SPEC
 
 }  // namespace functor
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
index 7bffb5a..cb757ac 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.cc
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -175,7 +175,7 @@
     OP_REQUIRES_OK_ASYNC(
         ctx, GetInitializableLookupTable("table_handle", ctx, &table), done);
     core::ScopedUnref unref_me(table);
-    DatasetBase* dataset;
+    data::DatasetBase* dataset;
     OP_REQUIRES_OK_ASYNC(
         ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
     background_worker_.Schedule([ctx, dataset, table, done]() {
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc
index fc1e2fe..d07b525 100644
--- a/tensorflow/core/kernels/lookup_util.cc
+++ b/tensorflow/core/kernels/lookup_util.cc
@@ -396,12 +396,12 @@
 
 class DatasetIterator : public InitializableLookupTable::InitTableIterator {
  public:
-  explicit DatasetIterator(DatasetBase* dataset) : dataset_(dataset) {}
+  explicit DatasetIterator(data::DatasetBase* dataset) : dataset_(dataset) {}
 
   ~DatasetIterator() override {}
 
   Status Init(OpKernelContext* ctx) {
-    IteratorContext::Params params(ctx);
+    data::IteratorContext::Params params(ctx);
     function_handle_cache_ =
         absl::make_unique<data::FunctionHandleCache>(params.flr);
     params.function_handle_cache = function_handle_cache_.get();
@@ -409,7 +409,7 @@
     cancellation_manager_ =
         absl::make_unique<CancellationManager>(ctx->cancellation_manager());
     params.cancellation_manager = cancellation_manager_.get();
-    iterator_ctx_ = absl::make_unique<IteratorContext>(std::move(params));
+    iterator_ctx_ = absl::make_unique<data::IteratorContext>(std::move(params));
     TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(), nullptr,
                                               "LookupTable", &iterator_));
     Next();
@@ -442,12 +442,12 @@
   }
 
  private:
-  DatasetBase* dataset_;  // not owned.
-  std::unique_ptr<IteratorContext> iterator_ctx_;
+  data::DatasetBase* dataset_;  // not owned.
+  std::unique_ptr<data::IteratorContext> iterator_ctx_;
   std::unique_ptr<data::FunctionHandleCache> function_handle_cache_;
   ResourceMgr resource_mgr_;
   std::unique_ptr<CancellationManager> cancellation_manager_;
-  std::unique_ptr<IteratorBase> iterator_;
+  std::unique_ptr<data::IteratorBase> iterator_;
   std::vector<Tensor> tensors_;
   Status status_;
 };
diff --git a/tensorflow/core/kernels/map_kernels.cc b/tensorflow/core/kernels/map_kernels.cc
index 53299cc..4430fdf 100644
--- a/tensorflow/core/kernels/map_kernels.cc
+++ b/tensorflow/core/kernels/map_kernels.cc
@@ -41,6 +41,9 @@
 REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
                         TensorMapHasKey);
 
+REGISTER_KERNEL_BUILDER(Name("TensorMapStackKeys").Device(DEVICE_CPU),
+                        TensorMapStackKeys);
+
 #undef REGISTER_TENSOR_MAP_OPS_CPU
 
 #define REGISTER_TENSOR_MAP_OPS_CPU(T)
diff --git a/tensorflow/core/kernels/map_kernels.h b/tensorflow/core/kernels/map_kernels.h
index 0a1bd12..cf45db1 100644
--- a/tensorflow/core/kernels/map_kernels.h
+++ b/tensorflow/core/kernels/map_kernels.h
@@ -15,39 +15,37 @@
 #ifndef TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
 #define TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
 
-#include <iostream>
-
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/kernels/tensor_map.h"
+#include "tensorflow/core/util/batch_util.h"
 #include "tensorflow/core/util/tensor_ops_util.h"
 
 namespace tensorflow {
 
-Status GetInputMap(OpKernelContext* c, int index, const TensorMap** map) {
-  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
+Status GetInputMap(OpKernelContext* ctx, int index, const TensorMap** ret_map) {
+  if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) {
     return errors::InvalidArgument("Input map must be a scalar. Saw: ",
-                                   c->input(index).shape().DebugString());
+                                   ctx->input(index).shape().DebugString());
   }
-  const TensorMap* m = c->input(index).scalar<Variant>()().get<TensorMap>();
-  if (m == nullptr) {
+  const TensorMap* map = ctx->input(index).scalar<Variant>()().get<TensorMap>();
+  if (map == nullptr) {
     return errors::InvalidArgument(
         "Input handle is not a map. Saw: '",
-        c->input(index).scalar<Variant>()().DebugString(), "'");
+        ctx->input(index).scalar<Variant>()().DebugString(), "'");
   }
-  *map = m;
+  *ret_map = map;
   return Status::OK();
 }
 
 // TODO(kattian): change into templated function
-Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
+Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, int32 input_index,
                                   int32 output_index,
                                   const TensorMap& input_map,
                                   TensorMap** output_map) {
   // Attempt to forward the input tensor to the output if possible.
-  std::unique_ptr<Tensor> maybe_output = c->forward_input(
+  std::unique_ptr<Tensor> maybe_output = ctx->forward_input(
       input_index, output_index, DT_VARIANT, TensorShape{},
-      c->input_memory_type(input_index), AllocatorAttributes());
+      ctx->input_memory_type(input_index), AllocatorAttributes());
   Tensor* output_tensor;
   if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
       maybe_output->NumElements() == 1) {
@@ -60,7 +58,7 @@
     }
     if (tmp_out->RefCountIsOne()) {
       // Woohoo, forwarding succeeded!
-      c->set_output(output_index, *output_tensor);
+      ctx->set_output(output_index, *output_tensor);
       *output_map = tmp_out;
       return Status::OK();
     }
@@ -71,7 +69,7 @@
   AllocatorAttributes attr;
   attr.set_on_host(true);
   TF_RETURN_IF_ERROR(
-      c->allocate_output(output_index, {}, &output_tensor, attr));
+      ctx->allocate_output(output_index, {}, &output_tensor, attr));
   output_tensor->scalar<Variant>()() = input_map.Copy();
 
   *output_map = output_tensor->scalar<Variant>()().get<TensorMap>();
@@ -80,13 +78,13 @@
 
 class EmptyTensorMap : public OpKernel {
  public:
-  explicit EmptyTensorMap(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
-  void Compute(OpKernelContext* c) override {
+  void Compute(OpKernelContext* ctx) override {
     Tensor* result;
     AllocatorAttributes attr;
     attr.set_on_host(true);
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
     TensorMap empty;
     result->scalar<Variant>()() = std::move(empty);
   }
@@ -94,87 +92,136 @@
 
 class TensorMapSize : public OpKernel {
  public:
-  explicit TensorMapSize(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapSize() override {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
     Tensor* result;
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
-    result->scalar<int32>()() = m->tensors().size();
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
+    result->scalar<int32>()() = map->tensors().size();
   }
 };
 
 class TensorMapLookup : public OpKernel {
  public:
-  explicit TensorMapLookup(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapLookup() override {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
 
-    OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(),
-                errors::InvalidArgument("Trying to lookup non-existent key."));
+    OP_REQUIRES(
+        ctx, map->tensors().find(key) != map->tensors().end(),
+        errors::InvalidArgument("Trying to lookup non-existent key. Could not "
+                                "find key \"" +
+                                key.SummarizeValue(100) + "\"."));
 
-    c->set_output(0, m->tensors().find(key)->second);
+    ctx->set_output(0, map->tensors().find(key)->second);
   }
 };
 
 class TensorMapInsert : public OpKernel {
  public:
-  explicit TensorMapInsert(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapInsert() override {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const Tensor& value = c->input(2);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const Tensor& value = ctx->input(2);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
 
     TensorMap* output_map = nullptr;
-    OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map));
+    OP_REQUIRES_OK(ctx,
+                   ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
     output_map->replace(key, value);
   }
 };
 
 class TensorMapErase : public OpKernel {
  public:
-  explicit TensorMapErase(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
 
-    OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(),
-                errors::InvalidArgument("Trying to erase non-existent item."));
+    OP_REQUIRES(
+        ctx, map->tensors().find(key) != map->tensors().end(),
+        errors::InvalidArgument("Trying to erase non-existent item. Could not "
+                                "find key \"" +
+                                key.SummarizeValue(100) + "\"."));
 
     TensorMap* output_map = nullptr;
-    OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map));
+    OP_REQUIRES_OK(ctx,
+                   ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
     output_map->tensors().erase(key);
   }
 };
 
 class TensorMapHasKey : public OpKernel {
  public:
-  explicit TensorMapHasKey(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapHasKey() override {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
     Tensor* result;
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
-    result->scalar<bool>()() = m->tensors().find(key) != m->tensors().end();
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
+    result->scalar<bool>()() = map->tensors().find(key) != map->tensors().end();
   }
 };
 
+class TensorMapStackKeys : public OpKernel {
+ public:
+  explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype", &key_dtype_));
+  }
+  ~TensorMapStackKeys() override {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
+
+    OP_REQUIRES(ctx, map->size() != 0,
+                errors::InvalidArgument(
+                    "TensorMapStackKeys cannot be called on an empty map."));
+
+    auto it = map->tensors().begin();
+    TensorShape output_shape = it->first.shape();
+    output_shape.InsertDim(0, map->tensors().size());
+    Tensor* result;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
+
+    size_t i = 0;
+    const size_t sz = map->tensors().size();
+    TensorShape key_shape = it->first.shape();
+    while (it != map->tensors().end() && i < sz) {
+      OP_REQUIRES(
+          ctx, it->first.dtype() == key_dtype_,
+          errors::InvalidArgument("Key does not match requested dtype."));
+      OP_REQUIRES(
+          ctx, it->first.shape() == key_shape,
+          errors::InvalidArgument("Keys must all have the same shape."));
+      OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i));
+      i++;
+      it++;
+    }
+  }
+
+ private:
+  DataType key_dtype_;
+};
+
 template <typename Device>
-Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
+Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a,
                           const TensorMap& b, TensorMap* out) {
   // Binary add returns a map containing the union of keys.
   // Values with keys in the intersection are added.
@@ -185,7 +232,7 @@
     if (it != out->tensors().end()) {
       Tensor out_tensor;
       TF_RETURN_IF_ERROR(
-          BinaryAddTensors<Device>(c, p.second, it->second, &out_tensor));
+          BinaryAddTensors<Device>(ctx, p.second, it->second, &out_tensor));
       it->second = out_tensor;
     } else {
       out->tensors().emplace(p.first, p.second);
@@ -195,7 +242,7 @@
 }
 
 template <typename Device>
-Status TensorMapZerosLike(OpKernelContext* c, const TensorMap& x,
+Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x,
                           TensorMap* y) {
   // Zeros like returns an empty map.
   return Status::OK();
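A toy sketch of the stacking the new TensorMapStackKeys kernel performs, with a std::map of scalars standing in for a TensorMap of Tensors: iterate the (sorted) map and copy each key into slice i of an output whose leading dimension is the map size.

#include <cstdio>
#include <map>
#include <vector>

int main() {
  const std::map<int, const char*> m = {{3, "c"}, {1, "a"}, {2, "b"}};
  std::vector<int> stacked;       // output with leading dim m.size()
  stacked.reserve(m.size());
  for (const auto& kv : m) {
    stacked.push_back(kv.first);  // CopyElementToSlice analogue
  }
  for (int k : stacked) std::printf("%d ", k);  // 1 2 3 (map iteration order)
  std::printf("\n");
  return 0;
}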
diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc
index 339ab93..c9bcdb5 100644
--- a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc
@@ -59,6 +59,8 @@
   memory::dims diff_bias_dims;
   memory::dims diff_dst_dims;
   memory::dims strides;
+  MKL_TENSOR_FORMAT tf_fmt;
+  bool native_format;
   memory::dims dilations;
   memory::dims padding_left;
   memory::dims padding_right;
@@ -69,6 +71,7 @@
   MklConvBwdFilterParams(memory::dims src_dims, memory::dims diff_filter_dims,
                          memory::dims diff_bias_dims,
                          memory::dims diff_dst_dims, memory::dims strides,
+                         MKL_TENSOR_FORMAT tf_fmt, bool native_format,
                          memory::dims dilations, memory::dims padding_left,
 #ifndef ENABLE_MKLDNN_V1
                          memory::dims padding_right, padding_kind padding)
@@ -80,6 +83,8 @@
         diff_bias_dims(diff_bias_dims),
         diff_dst_dims(diff_dst_dims),
         strides(strides),
+        tf_fmt(tf_fmt),
+        native_format(native_format),
         dilations(dilations),
         padding_left(padding_left),
 #ifndef ENABLE_MKLDNN_V1
@@ -243,15 +248,21 @@
   };
 
   void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
-    // Create memory descriptors for convolution backward filter without any
-    // specific format so that MKL-DNN can pick an appropriate one depending
-    // on the input parameters.
-    context_.src_md.reset(new memory::desc(
-        {convBwdFilterDims.src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
+    MEMORY_FORMAT user_data_fmt;
+    if (convBwdFilterDims.native_format) {
+      user_data_fmt =
+          MklTensorFormatToMklDnnDataFormat(convBwdFilterDims.tf_fmt);
+    } else {
+      // Create memory descriptors for convolution backward filter without any
+      // specific format so that MKL-DNN can pick an appropriate one depending
+      // on the input parameters.
+      user_data_fmt = MEMORY_FORMAT::any;
+    }
+    context_.src_md.reset(new memory::desc({convBwdFilterDims.src_dims},
+                                           MklDnnType<T>(), user_data_fmt));
 
-    context_.diff_dst_md.reset(
-        new memory::desc({convBwdFilterDims.diff_dst_dims}, MklDnnType<T>(),
-                         MEMORY_FORMAT::any));
+    context_.diff_dst_md.reset(new memory::desc(
+        {convBwdFilterDims.diff_dst_dims}, MklDnnType<T>(), user_data_fmt));
 
     context_.diff_filter_md.reset(
         new memory::desc({convBwdFilterDims.diff_filter_dims}, MklDnnType<T>(),
@@ -407,6 +418,9 @@
     key_creator.AddAsKey(convBwdFilterDims.dilations);
     key_creator.AddAsKey(convBwdFilterDims.padding_left);
     key_creator.AddAsKey(convBwdFilterDims.padding_right);
+    if (convBwdFilterDims.native_format) {
+      key_creator.AddAsKey(convBwdFilterDims.tf_fmt);
+    }
     return key_creator.GetKey();
   }
 
@@ -424,7 +438,7 @@
 };
 
 template <typename Device, class T, bool bias_enabled, bool is_depthwise,
-          bool eager_mode>
+          bool native_format>
 class MklConvCustomBackpropFilterOp
     : public MklConvBackpropCommonOp<Device, T, is_depthwise> {
  public:
@@ -441,9 +455,9 @@
       const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIdx);
 
       MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
-      GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode);
-      GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode);
-      GetMklShape(context, kDiffDstIdx, &diff_dst_mkl_shape, eager_mode);
+      GetMklShape(context, kInputIdx, &src_mkl_shape, native_format);
+      GetMklShape(context, kFilterIdx, &filter_mkl_shape, native_format);
+      GetMklShape(context, kDiffDstIdx, &diff_dst_mkl_shape, native_format);
       // Allow operator-specific sanity checking of shapes.
       ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);
 
@@ -455,7 +469,7 @@
       TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
       TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
       TensorShape diff_dst_tf_shape =
-          GetTfShape(context, kDiffDstIdx, eager_mode);
+          GetTfShape(context, kDiffDstIdx, native_format);
 
       // Corner cases: output with 0 elements and 0 batch size.
       Tensor* diff_filter_tensor = nullptr;
@@ -469,7 +483,7 @@
         const int kOutputIdx = 0;
         AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
                                   diff_filter_tf_shape, diff_filter_mkl_shape,
-                                  eager_mode);
+                                  native_format);
         DCHECK(diff_filter_tensor != nullptr);
 
         // If output tensor has more than 0 elements, we need to 0 them out.
@@ -534,6 +548,7 @@
       for (int i = 0; i < dilations.size(); ++i) --dilations[i];
       MklConvBwdFilterParams convBwdFilterDims(
           fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
+          tf_fmt, native_format,
 #ifndef ENABLE_MKLDNN_V1
           dilations, padding_left, padding_right,
           TFPaddingToMklDnnPadding(this->padding_));
@@ -566,7 +581,7 @@
                diff_filter_dims[MklDnnDims::Dim_O]});
           AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                     diff_filter_tf_shape, diff_filter_mkl_shape,
-                                    eager_mode);
+                                    native_format);
         } else {
           // Depthwise Conv2d: diff_filter_dims is GOIHW format.
           //                  | TensorFlow       | MKLDNN
@@ -710,7 +725,7 @@
   TensorShape MakeInputTfShape(OpKernelContext* context,
                                const Tensor& input_tensor) {
     size_t input_idx = 0;
-    return GetTfShape(context, input_idx, eager_mode);
+    return GetTfShape(context, input_idx, native_format);
   }
 
   // Get TensorFlow shape of filter tensor.
@@ -792,7 +807,7 @@
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
       MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>); \
   REGISTER_KERNEL_BUILDER(                                               \
-      Name("_MklEagerConv2DBackpropFilter")                              \
+      Name("_MklNativeConv2DBackpropFilter")                             \
           .Device(DEVICE_CPU)                                            \
           .TypeConstraint<T>("T")                                        \
           .Label(mkl_op_registry::kMklNameChangeOpLabel),                \
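A hedged sketch of why tf_fmt joins the primitive cache key only under native_format: once the user data format is fixed instead of left to MEMORY_FORMAT::any, primitives built for NHWC and NCHW inputs differ and must not collide in the cache. Illustrative string-keyed cache, not the MKL code:

#include <cstdio>
#include <map>
#include <string>

static std::string MakeKey(const std::string& dims, const std::string& fmt,
                           bool native_format) {
  std::string key = dims;
  if (native_format) key += ":" + fmt;  // mirrors AddAsKey(...tf_fmt)
  return key;
}

int main() {
  std::map<std::string, int> primitive_cache;
  primitive_cache[MakeKey("1x8x8x3", "NHWC", true)] = 1;
  primitive_cache[MakeKey("1x8x8x3", "NCHW", true)] = 2;  // distinct entry
  std::printf("cached primitives: %zu\n", primitive_cache.size());  // 2
  return 0;
}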
diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc
index 2e700d0..bfac57d 100644
--- a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc
@@ -63,6 +63,8 @@
   memory::dims filter_dims;
   memory::dims diff_dst_dims;
   memory::dims strides;
+  MKL_TENSOR_FORMAT tf_fmt;
+  bool native_format;
   memory::dims dilations;
   memory::dims padding_left;
   memory::dims padding_right;
@@ -72,6 +74,7 @@
 
   MklConvBwdInputParams(memory::dims diff_src_dims, memory::dims filter_dims,
                         memory::dims diff_dst_dims, memory::dims strides,
+                        MKL_TENSOR_FORMAT tf_fmt, bool native_format,
                         memory::dims dilations, memory::dims padding_left,
 #ifndef ENABLE_MKLDNN_V1
                         memory::dims padding_right, padding_kind padding)
@@ -82,6 +85,8 @@
         filter_dims(filter_dims),
         diff_dst_dims(diff_dst_dims),
         strides(strides),
+        tf_fmt(tf_fmt),
+        native_format(native_format),
         dilations(dilations),
         padding_left(padding_left),
 #ifndef ENABLE_MKLDNN_V1
@@ -215,15 +220,22 @@
   };
 
   void Setup(const MklConvBwdInputParams& convBwdInputDims) {
-    // Create memory descriptors for conv bwd input without any specified
-    // format so that MKL-DNN can pick an appropriate one depending on the
-    // input parameters.
+    MEMORY_FORMAT user_data_fmt;
+    if (convBwdInputDims.native_format) {
+      user_data_fmt =
+          MklTensorFormatToMklDnnDataFormat(convBwdInputDims.tf_fmt);
+    } else {
+      // Create memory descriptors for conv bwd input without any specified
+      // format so that MKL-DNN can pick an appropriate one depending on the
+      // input parameters.
+      user_data_fmt = MEMORY_FORMAT::any;
+    }
+    context_.diff_dst_md.reset(new memory::desc(
+        {convBwdInputDims.diff_dst_dims}, MklDnnType<T>(), user_data_fmt));
     context_.diff_src_md.reset(new memory::desc(
-        {convBwdInputDims.diff_src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
+        {convBwdInputDims.diff_src_dims}, MklDnnType<T>(), user_data_fmt));
     context_.filter_md.reset(new memory::desc(
         {convBwdInputDims.filter_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
-    context_.diff_dst_md.reset(new memory::desc(
-        {convBwdInputDims.diff_dst_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
 
     // Create descriptors for both conv fwd and conv bwd input.
     context_.bwd_input_desc.reset(new ConvBwdDataDesc(
@@ -343,6 +355,9 @@
     key_creator.AddAsKey(convBwdInputDims.dilations);
     key_creator.AddAsKey(convBwdInputDims.padding_left);
     key_creator.AddAsKey(convBwdInputDims.padding_right);
+    if (convBwdInputDims.native_format) {
+      key_creator.AddAsKey(convBwdInputDims.tf_fmt);
+    }
     return key_creator.GetKey();
   }
 
@@ -358,7 +373,7 @@
   }
 };
 
-template <typename Device, class T, bool is_depthwise, bool eager_mode>
+template <typename Device, class T, bool is_depthwise, bool native_format>
 class MklConvCustomBackpropInputOp
     : public MklConvBackpropCommonOp<Device, T, is_depthwise> {
  public:
@@ -375,9 +390,9 @@
       const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
 
       MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
-      GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode);
-      GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode);
-      GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, eager_mode);
+      GetMklShape(context, kInputIdx, &src_mkl_shape, native_format);
+      GetMklShape(context, kFilterIdx, &filter_mkl_shape, native_format);
+      GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, native_format);
       // Allow operator-specific sanity checking of shapes.
       ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);
 
@@ -397,7 +412,7 @@
 
       TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
       TensorShape diff_dst_tf_shape =
-          GetTfShape(context, kOutbpropIdx, eager_mode);
+          GetTfShape(context, kOutbpropIdx, native_format);
 
       // Corner cases: output with 0 elements and 0 batch size.
       Tensor* diff_src_tensor = nullptr;
@@ -411,7 +426,7 @@
         const int kOutputIdx = 0;
         AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor,
                                   diff_src_tf_shape, diff_src_mkl_shape,
-                                  eager_mode);
+                                  native_format);
         DCHECK(diff_src_tensor != nullptr);
 
         // If output tensor has more than 0 elements, we need to 0 them out.
@@ -475,7 +490,8 @@
       // 0 in MKL-DNN.
       for (int i = 0; i < dilations.size(); ++i) --dilations[i];
       MklConvBwdInputParams convBwdInputDims(
-          fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations,
+          fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, tf_fmt,
+          native_format, dilations,
 #ifndef ENABLE_MKLDNN_V1
           padding_left, padding_right,
           TFPaddingToMklDnnPadding(this->padding_));
@@ -511,13 +527,11 @@
                                      bwd_diff_src_dims, bwd_diff_src_format);
       TensorShape diff_src_tf_shape;
       diff_src_tf_shape.AddDim(diff_src_pd.get_size() / sizeof(T));
-      Tensor tmp_tensor;
-      if (eager_mode) {
-        AllocTmpBuffer<T>(context, &tmp_tensor, diff_src_tf_shape);
+      if (native_format) {
         diff_src_tf_shape = diff_src_mkl_shape.GetTfShape();
       }
       AllocateOutputSetMklShape(context, 0, &diff_src_tensor, diff_src_tf_shape,
-                                diff_src_mkl_shape, eager_mode);
+                                diff_src_mkl_shape, native_format);
       T* diff_src_data =
           static_cast<T*>(const_cast<T*>(diff_src_tensor->flat<T>().data()));
 
@@ -555,29 +569,8 @@
       std::shared_ptr<stream> bwd_cpu_stream;
       bwd_cpu_stream.reset(CreateStream(context, conv_bwd_input->GetEngine()));
       // Execute conv bwd input primitive.
-      if (!eager_mode) {
-        conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data,
-                                bwd_cpu_stream);
-      } else {
-        // In eager mode we first write the output to temporary
-        // buffer in MKL format. Then we convert the data to TF format.
-        T* tmp_data =
-            static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
-        conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data,
-                                bwd_cpu_stream);
-        auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
-#ifndef ENABLE_MKLDNN_V1
-        auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
-#endif
-        ReorderPd reorder_pd =
-            REORDER_PD_CONSTRUCTOR(diff_src_pd, OUTPUT_TF_MD, cpu_engine_);
-        memory* tmp_data_mem =
-            new MEMORY_CONSTRUCTOR(diff_src_pd, cpu_engine_, tmp_data);
-        memory* dst_data_mem =
-            new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data);
-        CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
-                                cpu_engine_, context);
-      }
+      conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data,
+                              bwd_cpu_stream);
 
       // Delete primitive since it is not cached.
       if (do_not_cache) {
@@ -625,7 +618,7 @@
   // Get TensorFlow shape of filter tensor.
   TensorShape MakeFilterTfShape(OpKernelContext* context,
                                 const Tensor& filter_tensor) {
-    return GetTfShape(context, kFilterIdx, eager_mode);
+    return GetTfShape(context, kFilterIdx, native_format);
   }
 
   // Get the Tensorflow shape of Output (diff_src),
@@ -683,7 +676,7 @@
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),   \
       MklConvCustomBackpropInputOp<CPUDevice, T, false, false>); \
   REGISTER_KERNEL_BUILDER(                                       \
-      Name("_MklEagerConv2DBackpropInput")                       \
+      Name("_MklNativeConv2DBackpropInput")                      \
           .Device(DEVICE_CPU)                                    \
           .TypeConstraint<T>("T")                                \
           .Label(mkl_op_registry::kMklNameChangeOpLabel),        \
diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc
index 84fa20e..2caa244 100644
--- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc
@@ -65,6 +65,8 @@
   memory::dims dilations;
   memory::dims padding_left;
   memory::dims padding_right;
+  MKL_TENSOR_FORMAT tf_fmt;
+  bool native_format;
   string dtypes = string("");
   struct PostOpParam {
     string name;
@@ -77,7 +79,8 @@
   MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
                    memory::dims bias_dims, memory::dims dst_dims,
                    memory::dims strides, memory::dims dilations,
-                   memory::dims padding_left, memory::dims padding_right)
+                   memory::dims padding_left, memory::dims padding_right,
+                   MKL_TENSOR_FORMAT tf_fmt, bool native_format)
       : src_dims(src_dims),
         filter_dims(filter_dims),
         bias_dims(bias_dims),
@@ -85,7 +88,9 @@
         strides(strides),
         dilations(dilations),
         padding_left(padding_left),
-        padding_right(padding_right) {}
+        padding_right(padding_right),
+        tf_fmt(tf_fmt),
+        native_format(native_format) {}
 };
 
 // With quantization, input, filter, and output can have different types
@@ -228,15 +233,21 @@
   };
 
   void Setup(const MklConvFwdParams& convFwdDims) {
-    // Create memory descriptors for convolution data w/ no specified format
+    MEMORY_FORMAT user_data_fmt;
+    if (convFwdDims.native_format) {
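+      // In native format the user tensors keep the TF layout, so use the
+      // matching MKL-DNN data format instead of letting the primitive choose.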
+      user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
+    } else {
+      // Create memory descriptors for convolution data w/ no specified format
+      user_data_fmt = MEMORY_FORMAT::any;
+    }
     context_.src_md.reset(new memory::desc(
-        {convFwdDims.src_dims}, MklDnnType<Tinput>(), MEMORY_FORMAT::any));
+        {convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));
 
     context_.filter_md.reset(new memory::desc(
         {convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any));
 
     context_.dst_md.reset(new memory::desc(
-        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), MEMORY_FORMAT::any));
+        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));
 
     if (!convFwdDims.bias_dims.empty())
       context_.bias_md.reset(new memory::desc(
@@ -414,6 +425,9 @@
     key_creator.AddAsKey(convFwdDims.padding_left);
     key_creator.AddAsKey(convFwdDims.padding_right);
     key_creator.AddAsKey(convFwdDims.dtypes);
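+    // In native-format mode the primitive is built for a specific TF data
+    // format, so the format must be part of the cache key.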
+    if (convFwdDims.native_format) {
+      key_creator.AddAsKey(convFwdDims.tf_fmt);
+    }
 
     // Generate keys for post-ops
     for (auto const& post_op_param : convFwdDims.post_op_params) {
@@ -453,7 +467,7 @@
 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
           typename Toutput, typename Ttemp_output, typename Tpadding,
           bool bias_enabled, bool pad_enabled, bool is_depthwise,
-          bool eager_mode>
+          bool native_format>
 class MklConvOp : public OpKernel {
  public:
   ~MklConvOp() {}
@@ -525,8 +539,9 @@
       const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
       const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
       MklDnnShape src_mkl_shape, filter_mkl_shape;
-      GetMklShape(context, kInputIndex_Src, &src_mkl_shape, eager_mode);
-      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape, eager_mode);
+      GetMklShape(context, kInputIndex_Src, &src_mkl_shape, native_format);
+      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape,
+                  native_format);
 
       OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(),
                   errors::InvalidArgument("Filter should not be in "
@@ -557,9 +572,9 @@
       // Get shapes of input tensors in MKL-DNN order
       MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
                               dilations_);
-      auto src_tf_shape = GetTfShape(context, kInputIndex_Src, eager_mode);
+      auto src_tf_shape = GetTfShape(context, kInputIndex_Src, native_format);
       auto filter_tf_shape =
-          GetTfShape(context, kInputIndex_Filter, eager_mode);
+          GetTfShape(context, kInputIndex_Filter, native_format);
       conv_utl.GetConvFwdSizesInMklOrder(
           src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
           &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
@@ -572,17 +587,16 @@
 
       // Corner cases: output with 0 elements and 0 batch size.
       Tensor* dst_tensor = nullptr;
-      Tensor tmp_tensor;
       bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) &&
                                  typeid(Tinput) == typeid(Toutput) &&
                                  (typeid(Tinput) == typeid(float) ||
                                   typeid(Tinput) == typeid(bfloat16))) &&
-                                !eager_mode;
+                                !native_format;
       if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
         MklDnnShape dst_mkl_shape;
         dst_mkl_shape.SetMklTensor(false);
         AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
-                                  src_tf_shape, dst_mkl_shape, eager_mode);
+                                  src_tf_shape, dst_mkl_shape, native_format);
 
         // MklConv2D/3D also outputs converted filter as 2nd output.
         filter_mkl_shape.SetMklTensor(false);
@@ -682,18 +696,19 @@
       }
       MklConvFwdParams convFwdDims(
           src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS,
-          dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
+          dst_dims_mkl_order, strides, dilations, padding_left, padding_right,
+          tf_fmt, native_format);
 
       // TODO(mdfaijul): Extend the basic parameters for data types and fusions
       this->ExtendConvFwdParams(context, convFwdDims);
       conv_fwd =
           MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
               convFwdDims, do_not_cache);
-      // Allocate output tensors `output_tensor` and `filter_out_tensor`
+      // Allocate output tensors `dst_tensor` and `filter_out_tensor`
       MklDnnShape output_mkl_shape;
       std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
       AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
-                           &output_mkl_shape, &dst_tensor, &tmp_tensor);
+                           &output_mkl_shape, &dst_tensor);
 
       Tensor* filter_out_tensor = nullptr;
       if (emit_filter_output) {
@@ -772,30 +787,7 @@
         conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
                           fwd_cpu_stream);
       } else {
-        if (!eager_mode) {
-          conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream);
-        } else {
-          // In eager mode we first write the output to temporary
-          // buffer in MKL format. Then we convert the data to TF format.
-          Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
-              tmp_tensor.flat<Toutput>().data());
-          conv_fwd->Execute(src_data, filter_data, tmp_data, fwd_cpu_stream);
-
-          // Now we need to convert the output to TF format.
-          auto output_tf_md = output_mkl_shape.GetTfLayout();
-#ifndef ENABLE_MKLDNN_V1
-          auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
-#endif  // !ENABLE_MKLDNN_V1
-          auto dst_pd = conv_fwd_pd->PRIMITIVE_DESC_DST;
-          ReorderPd reorder_pd =
-              REORDER_PD_CONSTRUCTOR(dst_pd, OUTPUT_TF_MD, cpu_engine_);
-          memory* tmp_data_mem =
-              new MEMORY_CONSTRUCTOR(dst_pd, cpu_engine_, tmp_data);
-          memory* dst_data_mem =
-              new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, dst_data);
-          CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
-                                  cpu_engine_, context);
-        }
+        conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream);
       }
 
       // Delete primitive since it is not cached.
@@ -911,8 +903,7 @@
                                     const memory::dims& output_dims_mkl_order,
                                     MKL_TENSOR_FORMAT output_tf_format,
                                     MklDnnShape* output_mkl_shape,
-                                    Tensor** output_tensor,
-                                    Tensor* tmp_tensor) {
+                                    Tensor** output_tensor) {
     DCHECK(output_tensor);
 #ifdef ENABLE_MKLDNN_V1
     auto dst_md = conv_prim_desc.dst_desc();
@@ -939,8 +930,7 @@
     // Allocate shape of TF tensor
     TensorShape output_tf_shape;
     output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
-    if (eager_mode) {
-      AllocTmpBuffer<Toutput>(context, tmp_tensor, output_tf_shape);
+    if (native_format) {
       output_tf_shape = output_mkl_shape->GetTfShape();
     }
 
@@ -957,7 +947,7 @@
       } else {
         AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                                   output_tf_shape, *output_mkl_shape,
-                                  eager_mode);
+                                  native_format);
 #ifdef ENABLE_MKLDNN_V1
         auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
             output_mkl_shape->GetTfDataFormat());
@@ -991,7 +981,8 @@
       }
     } else {
       AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
-                                output_tf_shape, *output_mkl_shape, eager_mode);
+                                output_tf_shape, *output_mkl_shape,
+                                native_format);
     }
   }
 
@@ -1836,8 +1827,7 @@
                             const memory::dims& output_dims_mkl_order,
                             MKL_TENSOR_FORMAT output_tf_format,
                             MklDnnShape* output_mkl_shape,
-                            Tensor** output_tensor,
-                            Tensor* tmp_tensor) override {
+                            Tensor** output_tensor) override {
     int summand_idx = context->num_inputs() / 2 - 1;
     if (std::is_same<Toutput, quint8>::value) {
       summand_idx -= 2;
@@ -1869,7 +1859,7 @@
               false>::AllocateOutputTensor(context, conv_prim_desc,
                                            output_dims_mkl_order,
                                            output_tf_format, output_mkl_shape,
-                                           output_tensor, tmp_tensor);
+                                           output_tensor);
     const Tensor& summand = MklGetInput(context, summand_idx);
     if (summand.dtype() != DT_FLOAT)
       TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
@@ -2432,7 +2422,7 @@
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
       MklDummyOp<CPUDevice, T>);                                               \
   REGISTER_KERNEL_BUILDER(                                                     \
-      Name("_MklEagerConv2D")                                                  \
+      Name("_MklNativeConv2D")                                                 \
           .Device(DEVICE_CPU)                                                  \
           .TypeConstraint<T>("T")                                              \
           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index ad0712e..aa736ad 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -175,8 +175,22 @@
     }
   }
 
+  // If the variant tensor input is empty, then we have no way to determine
+  // the correct shape for the dense_values.  (It must have rank>=1, and its
+  // outer dimension must be 0, but we don't know its shape beyond that.)
+  // For now, we just use a shape of `[0]` in this case.
+  // TODO(edloper): Update this op with an attribute containing information
+  // about dense_values shape.  If it's `None`, then we'll probably still have
+  // to use shape=[0] here, but if we have more info, then we can use it.
+  // E.g., in map_fn, we may have shape info from the RaggedTensorSpec.
+  TensorShape component_values_shape;
+  if (ragged_components.empty()) {
+    component_values_shape = TensorShape({0});
+  } else {
+    component_values_shape = ragged_components[0].values.shape();
+  }
+
   // Populate values.
-  TensorShape component_values_shape = ragged_components[0].values.shape();
   int values_size = component_values_shape.dim_size(0);
   for (int i = 1; i < ragged_components.size(); i++) {
     if (ragged_components[i].values.dims() != component_values_shape.dims()) {
diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc
deleted file mode 100644
index d5277c5..0000000
--- a/tensorflow/core/kernels/summary_op.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
-// inputs or outputs in various ways.
-
-// See docs in ../ops/summary_ops.cc.
-
-#include <unordered_set>
-
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/summary.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/histogram/histogram.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-
-class SummaryMergeOp : public OpKernel {
- public:
-  explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {}
-
-  void Compute(OpKernelContext* c) override {
-    Summary s;
-    std::unordered_set<string> tags;
-    for (int input_num = 0; input_num < c->num_inputs(); input_num++) {
-      const Tensor& in = c->input(input_num);
-      auto in_vec = in.flat<tstring>();
-      for (int i = 0; i < in_vec.dimension(0); i++) {
-        const string& s_in = in_vec(i);
-        Summary summary_in;
-        if (!ParseProtoUnlimited(&summary_in, s_in)) {
-          c->SetStatus(errors::InvalidArgument(
-              "Could not parse one of the summary inputs"));
-          return;
-        }
-
-        for (int v = 0; v < summary_in.value_size(); v++) {
-          const string& tag = summary_in.value(v).tag();
-          // The tag is unused by the TensorSummary op, so no need to check
-          // for duplicates.
-          if ((!tag.empty()) && !tags.insert(tag).second) {
-            c->SetStatus(errors::InvalidArgument(strings::StrCat(
-                "Duplicate tag ", tag, " found in summary inputs")));
-            return;
-          }
-          *s.add_value() = summary_in.value(v);
-        }
-      }
-    }
-
-    Tensor* summary_tensor = nullptr;
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
-    CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()()));
-  }
-};
-
-REGISTER_KERNEL_BUILDER(Name("MergeSummary").Device(DEVICE_CPU),
-                        SummaryMergeOp);
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/tensor_map.h b/tensorflow/core/kernels/tensor_map.h
index d29d244..cb4c827 100644
--- a/tensorflow/core/kernels/tensor_map.h
+++ b/tensorflow/core/kernels/tensor_map.h
@@ -144,7 +144,19 @@
   size_t erase(TensorKey key) { return tensors_->values_.erase(key); }
 
   // Size returns the number of elements in the map
-  size_t size() { return tensors_->values_.size(); }
+  size_t size() const { return tensors_->values_.size(); }
+
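+  // Returns all keys in the map as a vector of Tensors. The order follows
+  // the underlying hash map's iteration order and is unspecified.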
+  std::vector<Tensor> keys() const {
+    std::vector<Tensor> keys;
+    keys.reserve(tensors_->values_.size());
+    absl::flat_hash_map<TensorKey, Tensor>::iterator it =
+        tensors_->values_.begin();
+    while (it != tensors_->values_.end()) {
+      keys.push_back(it->first);
+      it++;
+    }
+    return keys;
+  }
 
   // Is this TensorMap the only one with a reference to the underlying
   // container?
diff --git a/tensorflow/core/kernels/tensor_map_test.cc b/tensorflow/core/kernels/tensor_map_test.cc
index beaff6f..76c903f 100644
--- a/tensorflow/core/kernels/tensor_map_test.cc
+++ b/tensorflow/core/kernels/tensor_map_test.cc
@@ -14,7 +14,6 @@
 ==============================================================================*/
 
 #include "tensorflow/core/kernels/tensor_map.h"
-
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -45,7 +44,6 @@
 }
 
 TEST(TensorMapTest, Insert) {
-  EXPECT_EQ(1, 1);
   TensorMap tm;
   TensorKey k = Tensor(11);
   Tensor v = Tensor(22);
@@ -102,12 +100,49 @@
   Tensor v1 = Tensor(22);
   Tensor v2 = Tensor(23);
   tm[k] = v2;
-
   absl::flat_hash_map<TensorKey, Tensor>::iterator map_it = tm.find(k);
   EXPECT_EQ(map_it->first, k);
   test::ExpectTensorEqual<int32>(map_it->second, v2);
 }
 
+TEST(TensorMapTest, ListKeys) {
+  TensorMap tm;
+  TensorKey k = Tensor(11.0);
+  TensorKey k2 = Tensor(12.0);
+  Tensor v = Tensor(22);
+  Tensor v2 = Tensor(23);
+  tm.insert(k, v);
+  tm.insert(k2, v2);
+  std::vector<Tensor> keys = tm.keys();
+
+  // Extract and sort the double value of each key Tensor.
+  std::vector<std::pair<double, int>> key_doubles;
+  for (int i = 0; i < keys.size(); i++) {
+    double x = keys[i].scalar<double>()();
+    std::pair<double, int> p = std::pair<double, int>(x, i);
+    key_doubles.push_back(p);
+  }
+  std::sort(key_doubles.begin(), key_doubles.end());
+  // Check number of keys and each key.
+  EXPECT_EQ(keys.size(), 2);
+  EXPECT_EQ(key_doubles[0].first, 11.0);
+  EXPECT_EQ(key_doubles[1].first, 12.0);
+  // Check key shapes.
+  int ind1 = key_doubles[0].second;
+  int ind2 = key_doubles[1].second;
+  EXPECT_EQ(keys[ind1].shape(), k.shape());
+  EXPECT_EQ(keys[ind2].shape(), k2.shape());
+}
+
+TEST(TensorMapTest, Size) {
+  TensorMap tm;
+  EXPECT_EQ(tm.size(), 0);
+  TensorKey k = Tensor(11);
+  Tensor v = Tensor(22);
+  tm.insert(k, v);
+  EXPECT_EQ(tm.size(), 1);
+}
+
 TEST(TensorMapTest, Copy) {
   TensorMap tm;
   TensorKey k = Tensor(11);
diff --git a/tensorflow/core/lib/lmdb/BUILD b/tensorflow/core/lib/lmdb/BUILD
new file mode 100644
index 0000000..c863d4c
--- /dev/null
+++ b/tensorflow/core/lib/lmdb/BUILD
@@ -0,0 +1,28 @@
+# Description:
+# LMDB test data packages.
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "lmdb_testdata",
+    testonly = 1,
+    srcs = [
+        # A simple key-value store:
+        #   0 : 'b'
+        #   1 : 'b'
+        #    ...
+        #   9 : 'b'
+        # Which is then overwritten with:
+        #   0 : 'a'
+        #   1 : 'b'
+        #    ...
+        #   9 : 'j'
+        "testdata/data.mdb",
+        # LMDB, being a memory-mapped database, uses a different file format on
+        # big-endian systems.
+        "testdata/data_bigendian.mdb",
+    ],
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/core/lib/psnr/BUILD b/tensorflow/core/lib/psnr/BUILD
new file mode 100644
index 0000000..386f1a5
--- /dev/null
+++ b/tensorflow/core/lib/psnr/BUILD
@@ -0,0 +1,15 @@
+package(
+    default_visibility = [
+        "//tensorflow/core:__pkg__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "testdata",
+    srcs = [
+        "testdata/cat_q20.jpg",
+        "testdata/cat_q72.jpg",
+        "testdata/cat_q95.jpg",
+    ],
+)
diff --git a/tensorflow/core/lib/ssim/BUILD b/tensorflow/core/lib/ssim/BUILD
new file mode 100644
index 0000000..7d9b72b
--- /dev/null
+++ b/tensorflow/core/lib/ssim/BUILD
@@ -0,0 +1,15 @@
+package(
+    default_visibility = [
+        "//tensorflow/core:__pkg__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "testdata",
+    srcs = [
+        "testdata/checkerboard1.png",
+        "testdata/checkerboard2.png",
+        "testdata/checkerboard3.png",
+    ],
+)
diff --git a/tensorflow/core/ops/compat/BUILD b/tensorflow/core/ops/compat/BUILD
index 1b1aea3..47ab66c 100644
--- a/tensorflow/core/ops/compat/BUILD
+++ b/tensorflow/core/ops/compat/BUILD
@@ -32,11 +32,14 @@
 tf_cc_test(
     name = "backwards_compatibility_test",
     size = "small",
-    srcs = ["backwards_compatibility_test.cc"],
+    srcs = [
+        "backwards_compatibility_test.cc",
+    ],
     data = [
         "//tensorflow/core:ops/ops.pbtxt",
+        "//tensorflow/core/ops/compat/ops_history_v1:ops_history_v1_srcs",
+        "//tensorflow/core/ops/compat/ops_history_v2:ops_history_v2_srcs",
     ] + glob([
-        "ops_history_v*/*.pbtxt",
         "ops_history.v*.pbtxt",
     ]),
     tags = [
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BUILD b/tensorflow/core/ops/compat/ops_history_v1/BUILD
new file mode 100644
index 0000000..dfd7dab
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BUILD
@@ -0,0 +1,16 @@
+# Description:
+# Test for keeping the history of OpDefs for every major version of TensorFlow,
+# to validate that we don't make backwards-incompatible changes, in particular
+# for v1.
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "ops_history_v1_srcs",
+    srcs = glob([
+        "*.pbtxt",
+    ]),
+    visibility = ["//tensorflow/core/ops/compat:__pkg__"],
+)
diff --git a/tensorflow/core/ops/compat/ops_history_v2/BUILD b/tensorflow/core/ops/compat/ops_history_v2/BUILD
new file mode 100644
index 0000000..a746280
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/BUILD
@@ -0,0 +1,16 @@
+# Description:
+# Test for keeping the history of OpDefs for every major version of TensorFlow,
+# to validate that we don't make backwards-incompatible changes, in particular
+# for v2.
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "ops_history_v2_srcs",
+    srcs = glob([
+        "*.pbtxt",
+    ]),
+    visibility = ["//tensorflow/core/ops/compat:__pkg__"],
+)
diff --git a/tensorflow/core/ops/compat/ops_history_v2/Max.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/Max.pbtxt
index 4c931cc..bf147ac 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/Max.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/Max.pbtxt
@@ -32,8 +32,6 @@
         type: DT_UINT16
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
         type: DT_QINT8
         type: DT_QUINT8
         type: DT_QINT32
@@ -89,8 +87,6 @@
         type: DT_UINT16
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
         type: DT_QINT8
         type: DT_QUINT8
         type: DT_QINT32
@@ -148,8 +144,6 @@
         type: DT_UINT16
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
         type: DT_QINT8
         type: DT_QUINT8
         type: DT_QINT32
@@ -206,14 +200,12 @@
         type: DT_UINT8
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
         type: DT_INT64
         type: DT_QINT8
         type: DT_QUINT8
         type: DT_QINT32
         type: DT_BFLOAT16
         type: DT_UINT16
-        type: DT_COMPLEX128
         type: DT_HALF
         type: DT_UINT32
         type: DT_UINT64
@@ -234,3 +226,63 @@
     }
   }
 }
+op {
+  name: "Max"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/Min.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/Min.pbtxt
index f0ebdb0..4959b5e 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/Min.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/Min.pbtxt
@@ -32,8 +32,6 @@
         type: DT_UINT16
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
         type: DT_QINT8
         type: DT_QUINT8
         type: DT_QINT32
@@ -89,8 +87,6 @@
         type: DT_UINT16
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
         type: DT_QINT8
         type: DT_QUINT8
         type: DT_QINT32
@@ -148,8 +144,6 @@
         type: DT_UINT16
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
         type: DT_QINT8
         type: DT_QUINT8
         type: DT_QINT32
@@ -206,14 +200,12 @@
         type: DT_UINT8
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
         type: DT_INT64
         type: DT_QINT8
         type: DT_QUINT8
         type: DT_QINT32
         type: DT_BFLOAT16
         type: DT_UINT16
-        type: DT_COMPLEX128
         type: DT_HALF
         type: DT_UINT32
         type: DT_UINT64
@@ -234,3 +226,63 @@
     }
   }
 }
+op {
+  name: "Min"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "reduction_indices"
+    type_attr: "Tidx"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "keep_dims"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_INT64
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "Tidx"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/TensorMapErase.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/TensorMapErase.pbtxt
new file mode 100644
index 0000000..854e731
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/TensorMapErase.pbtxt
@@ -0,0 +1,23 @@
+op {
+  name: "TensorMapErase"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key"
+    type_attr: "key_dtype"
+  }
+  output_arg {
+    name: "output_handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+  attr {
+    name: "value_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/compat/ops_history_v2/TensorMapHasKey.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/TensorMapHasKey.pbtxt
index 4378227..a095c36 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/TensorMapHasKey.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/TensorMapHasKey.pbtxt
@@ -6,14 +6,14 @@
   }
   input_arg {
     name: "key"
-    type_attr: "element_dtype"
+    type_attr: "key_dtype"
   }
   output_arg {
     name: "has_key"
     type: DT_BOOL
   }
   attr {
-    name: "element_dtype"
+    name: "key_dtype"
     type: "type"
   }
 }
diff --git a/tensorflow/core/ops/compat/ops_history_v2/TensorMapStackKeys.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/TensorMapStackKeys.pbtxt
new file mode 100644
index 0000000..c3befaa
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v2/TensorMapStackKeys.pbtxt
@@ -0,0 +1,15 @@
+op {
+  name: "TensorMapStackKeys"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "keys"
+    type_attr: "key_dtype"
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+}
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index acc5a54..ae9822c 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -116,12 +116,6 @@
     .SetShapeFn(shape_inference::ScalarShape)
     .Deprecated(15, "Use AudioSummaryV2.");
 
-REGISTER_OP("MergeSummary")
-    .Input("inputs: N * string")
-    .Output("summary: string")
-    .Attr("N : int >= 1")
-    .SetShapeFn(shape_inference::ScalarShape);
-
 REGISTER_OP("Timestamp")
     .Output("ts: float64")
     .SetIsStateful()
diff --git a/tensorflow/core/ops/map_ops.cc b/tensorflow/core/ops/map_ops.cc
index f520751..d54ef54 100644
--- a/tensorflow/core/ops/map_ops.cc
+++ b/tensorflow/core/ops/map_ops.cc
@@ -63,16 +63,25 @@
     .Attr("key_dtype: type")
     .Attr("value_dtype: type")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      c->set_output(0, c->Scalar());        // output map
+      c->set_output(0, c->Scalar());  // output map
       return Status::OK();
     });
 
 REGISTER_OP("TensorMapHasKey")
     .Input("input_handle: variant")
-    .Input("key: element_dtype")
+    .Input("key: key_dtype")
     .Output("has_key: bool")
-    .Attr("element_dtype: type")
+    .Attr("key_dtype: type")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("TensorMapStackKeys")
+    .Input("input_handle: variant")
+    .Output("keys: key_dtype")
+    .Attr("key_dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
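+      // The number of keys is only known at runtime, so the output shape
+      // cannot be inferred statically.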
+      c->set_output(0, c->UnknownShape());  // output keys
+      return Status::OK();
+    });
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index cbf1ef5..3afdb67 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1026,7 +1026,7 @@
     .Input("reduction_indices: Tidx")
     .Output("output: T")
     .Attr("keep_dims: bool = false")
-    .Attr("T: numbertype")
+    .Attr("T: {realnumbertype, quantizedtype}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .SetShapeFn(shape_inference::ReductionShape);
 
@@ -1035,7 +1035,7 @@
     .Input("reduction_indices: Tidx")
     .Output("output: T")
     .Attr("keep_dims: bool = false")
-    .Attr("T: numbertype")
+    .Attr("T: {realnumbertype, quantizedtype}")
     .Attr("Tidx: {int32, int64} = DT_INT32")
     .SetShapeFn(shape_inference::ReductionShape);
 
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index a339e53..b34bb41 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -252,7 +252,9 @@
     .Attr("is_training: bool = true")
     .SetShapeFn(shape_inference::FusedBatchNormExShape)
     .Doc(R"doc(
-*NOTE*: Do not invoke this operator directly in Python. Grappler is
+Internal FusedBatchNorm operation: reserved for internal use.
+
+Do not invoke this operator directly in Python. A fusion optimization is
 expected to create these operators.
 )doc");
 
@@ -402,6 +404,8 @@
     .Attr("fused_ops: list(string) = []")
     // Attributes for the FusedBatchNorm ------------------------------------ //
     .Attr("epsilon: float = 0.0001")
+    // Attributes for the LeakyRelu ----------------------------------------- //
+    .Attr("leakyrelu_alpha: float = 0.2")
     // ---------------------------------------------------------------------- //
     .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
     .Doc(R"doc(
@@ -631,7 +635,10 @@
     .Attr("fused_ops: list(string) = []")
     // Attributes for the FusedBatchNorm ------------------------------------ //
     .Attr("epsilon: float = 0.0001")
+    // Attributes for the LeakyRelu ----------------------------------------- //
+    .Attr("leakyrelu_alpha: float = 0.2")
     // ---------------------------------------------------------------------- //
     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
 
 // --------------------------------------------------------------------------
@@ -1695,7 +1702,7 @@
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("_MklEagerConv2D")
+REGISTER_OP("_MklNativeConv2D")
     .Input("input: T")
     .Input("filter: T")
     .Output("output: T")
@@ -1845,7 +1852,7 @@
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("_MklEagerConv2DBackpropFilter")
+REGISTER_OP("_MklNativeConv2DBackpropFilter")
     .Input("input: T")
     .Input("filter_sizes: int32")
     .Input("out_backprop: T")
@@ -2006,7 +2013,7 @@
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("_MklEagerConv2DBackpropInput")
+REGISTER_OP("_MklNativeConv2DBackpropInput")
     .Input("input_sizes: int32")
     .Input("filter: T")
     .Input("out_backprop: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 50b54c9..cd9a234 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -23634,17 +23634,17 @@
         type: DT_UINT8
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
         type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
         type: DT_BFLOAT16
         type: DT_UINT16
-        type: DT_COMPLEX128
         type: DT_HALF
         type: DT_UINT32
         type: DT_UINT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
       }
     }
   }
@@ -24792,17 +24792,17 @@
         type: DT_UINT8
         type: DT_INT16
         type: DT_INT8
-        type: DT_COMPLEX64
         type: DT_INT64
-        type: DT_QINT8
-        type: DT_QUINT8
-        type: DT_QINT32
         type: DT_BFLOAT16
         type: DT_UINT16
-        type: DT_COMPLEX128
         type: DT_HALF
         type: DT_UINT32
         type: DT_UINT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
       }
     }
   }
@@ -53190,10 +53190,6 @@
     name: "output_handle"
     type: DT_VARIANT
   }
-  output_arg {
-    name: "value"
-    type_attr: "value_dtype"
-  }
   attr {
     name: "key_dtype"
     type: "type"
@@ -53211,14 +53207,14 @@
   }
   input_arg {
     name: "key"
-    type_attr: "element_dtype"
+    type_attr: "key_dtype"
   }
   output_arg {
     name: "has_key"
     type: DT_BOOL
   }
   attr {
-    name: "element_dtype"
+    name: "key_dtype"
     type: "type"
   }
 }
@@ -53284,6 +53280,21 @@
   }
 }
 op {
+  name: "TensorMapStackKeys"
+  input_arg {
+    name: "input_handle"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "keys"
+    type_attr: "key_dtype"
+  }
+  attr {
+    name: "key_dtype"
+    type: "type"
+  }
+}
+op {
   name: "TensorScatterAdd"
   input_arg {
     name: "tensor"
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 5d6f74f..6b35a45 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -54,6 +54,9 @@
 # buildifier: disable=same-origin-load
 load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
 
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+
 package(
     default_visibility = ["//tensorflow:__subpackages__"],
     licenses = ["notice"],  # Apache 2.0
@@ -127,6 +130,7 @@
 cc_library(
     name = "bfloat16",
     hdrs = ["bfloat16.h"],
+    compatible_with = get_compatible_with_portable(),
     deps = [
         ":byte_order",
         "//third_party/eigen3",
@@ -145,6 +149,7 @@
 cc_library(
     name = "byte_order",
     hdrs = ["byte_order.h"],
+    compatible_with = get_compatible_with_portable(),
 )
 
 cc_library(
@@ -168,6 +173,7 @@
 cc_library(
     name = "cord",
     hdrs = ["cord.h"],
+    compatible_with = get_compatible_with_portable(),
     deps = [
         ":platform",
     ] + tf_platform_deps("cord"),
@@ -440,6 +446,7 @@
 cc_library(
     name = "platform",
     hdrs = ["platform.h"],
+    compatible_with = get_compatible_with_portable(),
 )
 
 cc_library(
@@ -772,6 +779,7 @@
         "ctstring_internal.h",
         "tstring.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     deps = [
         ":cord",
         "@com_google_absl//absl/strings",
@@ -789,6 +797,7 @@
 cc_library(
     name = "types",
     hdrs = ["types.h"],
+    compatible_with = get_compatible_with_portable(),
     # TODO(b/161569340): Short-term fix. Remove this visibility rule.
     visibility = [
         "//tensorflow:__subpackages__",
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index ec28309..4b934fb 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -391,8 +391,8 @@
     size = "small",
     srcs = ["oauth_client_test.cc"],
     data = [
-        "testdata/service_account_credentials.json",
-        "testdata/service_account_public_key.txt",
+        "//tensorflow/core/platform/cloud/testdata:service_account_credentials",
+        "//tensorflow/core/platform/cloud/testdata:service_account_public_key",
     ],
     deps = [
         ":http_request_fake",
@@ -414,8 +414,8 @@
     size = "small",
     srcs = ["google_auth_provider_test.cc"],
     data = [
-        "testdata/application_default_credentials.json",
-        "testdata/service_account_credentials.json",
+        "//tensorflow/core/platform/cloud/testdata:application_default_credentials",
+        "//tensorflow/core/platform/cloud/testdata:service_account_credentials",
     ],
     deps = [
         ":google_auth_provider",
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index f0d2138..59eb610 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -30,7 +30,7 @@
 #include <io.h>  // for _mktemp
 #endif
 #include "absl/base/macros.h"
-#include "include/json/json.h"
+#include "json/json.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/cloud/curl_http_request.h"
 #include "tensorflow/core/platform/cloud/file_block_cache.h"
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc
index e8546ca..57240fa 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider.cc
@@ -24,7 +24,7 @@
 #include <utility>
 
 #include "absl/strings/match.h"
-#include "include/json/json.h"
+#include "json/json.h"
 #include "tensorflow/core/platform/base64.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
diff --git a/tensorflow/core/platform/cloud/oauth_client.h b/tensorflow/core/platform/cloud/oauth_client.h
index ed8bf25..97af3ec 100644
--- a/tensorflow/core/platform/cloud/oauth_client.h
+++ b/tensorflow/core/platform/cloud/oauth_client.h
@@ -18,7 +18,7 @@
 
 #include <memory>
 
-#include "include/json/json.h"
+#include "json/json.h"
 #include "tensorflow/core/platform/cloud/http_request.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/status.h"
diff --git a/tensorflow/core/platform/cloud/testdata/BUILD b/tensorflow/core/platform/cloud/testdata/BUILD
new file mode 100644
index 0000000..c7f1b8e
--- /dev/null
+++ b/tensorflow/core/platform/cloud/testdata/BUILD
@@ -0,0 +1,29 @@
+# Cloud test data files.
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "application_default_credentials",
+    srcs = [
+        "application_default_credentials.json",
+    ],
+    visibility = ["//tensorflow/core/platform/cloud:__pkg__"],
+)
+
+filegroup(
+    name = "service_account_credentials",
+    srcs = [
+        "service_account_credentials.json",
+    ],
+    visibility = ["//tensorflow/core/platform/cloud:__pkg__"],
+)
+
+filegroup(
+    name = "service_account_public_key",
+    srcs = [
+        "service_account_public_key.txt",
+    ],
+    visibility = ["//tensorflow/core/platform/cloud:__pkg__"],
+)
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index 7b71679..308d8a0 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -111,6 +111,13 @@
   Status NewRandomAccessFile(const std::string& fname,
                              std::unique_ptr<RandomAccessFile>* result);
 
+  Status NewRandomAccessFile(const std::string& fname, TransactionToken* token,
+                             std::unique_ptr<RandomAccessFile>* result) {
+    // We duplicate these methods because the Google internal coding style
+    // prevents virtual functions with default arguments. See PR #41615.
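+    // For now the token-taking overloads are stubs: they ignore `token` and
+    // return OK without doing any work.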
+    return Status::OK();
+  }
+
   /// \brief Creates an object that writes to a new file with the specified
   /// name.
   ///
@@ -127,6 +134,11 @@
   Status NewWritableFile(const std::string& fname,
                          std::unique_ptr<WritableFile>* result);
 
+  Status NewWritableFile(const std::string& fname, TransactionToken* token,
+                         std::unique_ptr<WritableFile>* result) {
+    return Status::OK();
+  }
+
   /// \brief Creates an object that either appends to an existing file, or
   /// writes to a new file (if the file does not exist to begin with).
   ///
@@ -142,6 +154,10 @@
   Status NewAppendableFile(const std::string& fname,
                            std::unique_ptr<WritableFile>* result);
 
+  Status NewAppendableFile(const std::string& fname, TransactionToken* token,
+                           std::unique_ptr<WritableFile>* result) {
+    return Status::OK();
+  }
   /// \brief Creates a readonly region of memory with the file context.
   ///
   /// On success, it returns a pointer to read-only memory region
@@ -156,21 +172,41 @@
   Status NewReadOnlyMemoryRegionFromFile(
       const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result);
 
+  Status NewReadOnlyMemoryRegionFromFile(
+      const std::string& fname, TransactionToken* token,
+      std::unique_ptr<ReadOnlyMemoryRegion>* result) {
+    return Status::OK();
+  }
+
   /// Returns OK if the named path exists and NOT_FOUND otherwise.
   Status FileExists(const std::string& fname);
 
+  Status FileExists(const std::string& fname, TransactionToken* token) {
+    return Status::OK();
+  }
+
   /// Returns true if all the listed files exist, false otherwise.
   /// if status is not null, populate the vector with a detailed status
   /// for each file.
   bool FilesExist(const std::vector<string>& files,
                   std::vector<Status>* status);
 
+  bool FilesExist(const std::vector<string>& files, TransactionToken* token,
+                  std::vector<Status>* status) {
+    return true;
+  }
+
   /// \brief Stores in *result the names of the children of the specified
   /// directory. The names are relative to "dir".
   ///
   /// Original contents of *results are dropped.
   Status GetChildren(const std::string& dir, std::vector<string>* result);
 
+  Status GetChildren(const std::string& dir, TransactionToken* token,
+                     std::vector<string>* result) {
+    return Status::OK();
+  }
+
   /// \brief Returns true if the path matches the given pattern. The wildcards
   /// allowed in pattern are described in FileSystem::GetMatchingPaths.
   virtual bool MatchPath(const std::string& path,
@@ -183,9 +219,18 @@
   virtual Status GetMatchingPaths(const std::string& pattern,
                                   std::vector<string>* results);
 
+  Status GetMatchingPaths(const std::string& pattern, TransactionToken* token,
+                          std::vector<string>* results) {
+    return Status::OK();
+  }
+
   /// Deletes the named file.
   Status DeleteFile(const std::string& fname);
 
+  Status DeleteFile(const std::string& fname, TransactionToken* token) {
+    return Status::OK();
+  }
+
   /// \brief Deletes the specified directory and all subdirectories and files
   /// underneath it. This is accomplished by traversing the directory tree
   /// rooted at dirname and deleting entries as they are encountered.
@@ -213,6 +258,11 @@
   Status DeleteRecursively(const std::string& dirname, int64* undeleted_files,
                            int64* undeleted_dirs);
 
+  Status DeleteRecursively(const std::string& dirname, TransactionToken* token,
+                           int64* undeleted_files, int64* undeleted_dirs) {
+    return Status::OK();
+  }
+
   /// \brief Creates the specified directory and all the necessary
   /// subdirectories. Typical return codes.
   ///  * OK - successfully created the directory and sub directories, even if
@@ -220,18 +270,35 @@
   ///  * PERMISSION_DENIED - dirname or some subdirectory is not writable.
   Status RecursivelyCreateDir(const std::string& dirname);
 
+  Status RecursivelyCreateDir(const std::string& dirname,
+                              TransactionToken* token) {
+    return Status::OK();
+  }
   /// \brief Creates the specified directory. Typical return codes
   ///  * OK - successfully created the directory.
   ///  * ALREADY_EXISTS - directory already exists.
   ///  * PERMISSION_DENIED - dirname is not writable.
   Status CreateDir(const std::string& dirname);
 
+  Status CreateDir(const std::string& dirname, TransactionToken* token) {
+    return Status::OK();
+  }
+
   /// Deletes the specified directory.
   Status DeleteDir(const std::string& dirname);
 
+  Status DeleteDir(const std::string& dirname, TransactionToken* token) {
+    return Status::OK();
+  }
+
   /// Obtains statistics for the given path.
   Status Stat(const std::string& fname, FileStatistics* stat);
 
+  Status Stat(const std::string& fname, TransactionToken* token,
+              FileStatistics* stat) {
+    return Status::OK();
+  }
+
   /// \brief Returns whether the given path is a directory or not.
   /// Typical return codes (not guaranteed exhaustive):
   ///  * OK - The path exists and is a directory.
@@ -256,13 +323,59 @@
   /// Stores the size of `fname` in `*file_size`.
   Status GetFileSize(const std::string& fname, uint64* file_size);
 
+  Status GetFileSize(const std::string& fname, TransactionToken* token,
+                     uint64* file_size) {
+    return Status::OK();
+  }
+
   /// \brief Renames file src to target. If target already exists, it will be
   /// replaced.
   Status RenameFile(const std::string& src, const std::string& target);
 
+  Status RenameFile(const std::string& src, const std::string& target,
+                    TransactionToken* token) {
+    return Status::OK();
+  }
+
   /// \brief Copy the src to target.
   Status CopyFile(const std::string& src, const std::string& target);
 
+  Status CopyFile(const std::string& src, const std::string& target,
+                  TransactionToken* token) {
+    return Status::OK();
+  }
+
+  /// \brief Starts a new transaction on the filesystem that handles `filename`.
+  Status StartTransaction(const std::string& filename,
+                          TransactionToken** token) {
+    *token = nullptr;
+    return Status::OK();
+  }
+
+  /// \brief Adds `path` to the transaction in `token` if the token belongs to
+  /// a filesystem that handles the path.
+  Status AddToTransaction(const std::string& path, TransactionToken* token) {
+    return Status::OK();
+  }
+
+  /// \brief Gets the token for `path`, or starts a new transaction and adds
+  /// `path` to it.
+  Status GetTokenOrStartTransaction(const std::string& path,
+                                    TransactionToken** token) {
+    *token = nullptr;
+    return Status::OK();
+  }
+
+  /// \brief Returns the transaction for `path`, or nullptr in `token` if
+  /// there is none.
+  Status GetTransactionForPath(const std::string& path,
+                               TransactionToken** token) {
+    *token = nullptr;
+    return Status::OK();
+  }
+
+  /// \brief Finalizes the transaction.
+  Status EndTransaction(TransactionToken* token) { return Status::OK(); }
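+
+  // A minimal usage sketch (illustrative only; the stubs above do no real
+  // work yet, and a filesystem must implement transactions for this to have
+  // any effect):
+  //   TransactionToken* token = nullptr;
+  //   TF_RETURN_IF_ERROR(env->StartTransaction("ram://db/part-0", &token));
+  //   TF_RETURN_IF_ERROR(env->AddToTransaction("ram://db/part-1", token));
+  //   // ... reads and writes against files in the transaction ...
+  //   TF_RETURN_IF_ERROR(env->EndTransaction(token));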
+
   /// \brief Returns the absolute path of the current executable. It resolves
   /// symlinks if there is any.
   std::string GetExecutablePath();
diff --git a/tensorflow/core/platform/macros.h b/tensorflow/core/platform/macros.h
index a38c57d..4f8e49d 100644
--- a/tensorflow/core/platform/macros.h
+++ b/tensorflow/core/platform/macros.h
@@ -74,6 +74,25 @@
 #define TF_HAS_BUILTIN(x) 0
 #endif
 
+// C++11-style attributes (N2761)
+#if defined(__has_cpp_attribute)
+// Safely checks if an attribute is supported. Equivalent to
+// ABSL_HAVE_CPP_ATTRIBUTE.
+#define TF_HAS_CPP_ATTRIBUTE(n) __has_cpp_attribute(n)
+#else
+#define TF_HAS_CPP_ATTRIBUTE(n) 0
+#endif
+
+// [[clang::annotate("x")]] allows attaching custom strings (e.g. "x") to
+// declarations (variables, functions, fields, etc.) for use by tools. They are
+// represented in the Clang AST (as AnnotateAttr nodes) and in LLVM IR, but not
+// in final output.
+#if TF_HAS_CPP_ATTRIBUTE(clang::annotate)
+#define TF_ATTRIBUTE_ANNOTATE(str) [[clang::annotate(str)]]
+#else
+#define TF_ATTRIBUTE_ANNOTATE(str)
+#endif
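+
+// Usage sketch (illustrative): the annotation is recorded in the Clang AST
+// and LLVM IR for tools to consume, but has no effect on the final binary.
+//   TF_ATTRIBUTE_ANNOTATE("tf-custom-tag")
+//   void SomeFunction();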
+
 // Compilers can be told that a certain branch is not likely to be taken
 // (for instance, a CHECK failure), and use that information in static
 // analysis. Giving it this information can help it optimize for the
diff --git a/tensorflow/core/platform/ram_file_system.h b/tensorflow/core/platform/ram_file_system.h
index 407bcb3..ce6d054 100644
--- a/tensorflow/core/platform/ram_file_system.h
+++ b/tensorflow/core/platform/ram_file_system.h
@@ -177,7 +177,7 @@
               FileStatistics* stat) override {
     mutex_lock m(mu_);
     auto it = fs_.lower_bound(fname);
-    if (it == fs_.end()) {
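+    // lower_bound() can land on an unrelated entry, so only treat it as a
+    // match if `fname` is actually a prefix of the stored path.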
+    if (it == fs_.end() || !absl::StartsWith(it->first, fname)) {
       return errors::NotFound("");
     }
 
diff --git a/tensorflow/core/platform/ram_file_system_test.py b/tensorflow/core/platform/ram_file_system_test.py
index 0f4f47e..960765d 100644
--- a/tensorflow/core/platform/ram_file_system_test.py
+++ b/tensorflow/core/platform/ram_file_system_test.py
@@ -21,6 +21,7 @@
 
 import numpy as np
 
+from tensorflow.python.eager import def_function
 from tensorflow.python.estimator.estimator import Estimator
 from tensorflow.python.estimator.model_fn import EstimatorSpec
 from tensorflow.python.estimator.run_config import RunConfig
@@ -28,9 +29,11 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.layers import core as core_layers
+from tensorflow.python.module import module
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
+from tensorflow.python.saved_model import saved_model
 from tensorflow.python.training import adam
 from tensorflow.python.training import training_util
 
@@ -82,6 +85,17 @@
     matches = ['ram://c/b/%d.txt' % i for i in range(10)]
     self.assertEqual(gfile.Glob('ram://c/b/*'), matches)
 
+  def test_file_exists(self):
+    with gfile.GFile('ram://exists/a/b/c.txt', 'w') as f:
+      f.write('')
+    self.assertTrue(gfile.Exists('ram://exists/a'))
+    self.assertTrue(gfile.Exists('ram://exists/a/b'))
+    self.assertTrue(gfile.Exists('ram://exists/a/b/c.txt'))
+
+    self.assertFalse(gfile.Exists('ram://exists/b'))
+    self.assertFalse(gfile.Exists('ram://exists/a/c'))
+    self.assertFalse(gfile.Exists('ram://exists/a/b/k'))
+
   def test_estimator(self):
 
     def model_fn(features, labels, mode, params):
@@ -114,6 +128,18 @@
     estimator.train(input_fn=input_fn, steps=10)
     estimator.train(input_fn=input_fn, steps=10)
 
+  def test_savedmodel(self):
+    class MyModule(module.Module):
+
+      @def_function.function(input_signature=[])
+      def foo(self):
+        return constant_op.constant([1])
+
+    saved_model.save(MyModule(), 'ram://my_module')
+
+    loaded = saved_model.load('ram://my_module')
+    self.assertAllEqual(loaded.foo(), [1])
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/core/platform/tf32_utils.cc b/tensorflow/core/platform/tf32_utils.cc
index d2f40ea..21059b9 100644
--- a/tensorflow/core/platform/tf32_utils.cc
+++ b/tensorflow/core/platform/tf32_utils.cc
@@ -20,8 +20,8 @@
 namespace tensorflow {
 
 // Whether TensorFloat-32 should be used where supported.
-// TODO(nluehr): Maybe enable by default after TF32 Ampere testing.
-static std::atomic<bool> tf32_allowed{false};
+// TODO(reedwm): Change word "allow" to "enable" in all TensorFloat-32 functions
+static std::atomic<bool> tf32_allowed{true};
 
 void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; }
 
diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
index 425bf00..4cf81f4 100644
--- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
+++ b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
@@ -36,6 +36,9 @@
   DCHECK(dst != nullptr);
   DCHECK_EQ(src.hlo_module_id(), dst->hlo_module_id());
   DCHECK_EQ(src.name(), dst->name());
+  if (dst->long_name().empty()) {
+    dst->set_long_name(src.long_name());
+  }
   if (dst->category().empty()) {
     dst->set_category(src.category());
   }
diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
index 276181d..8f58b7b 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
@@ -234,7 +234,8 @@
   uint64 outside_compilation_device_op_time_ps = 0;
   for (const OpMetrics& metrics :
        op_stats.device_op_metrics_db().metrics_db()) {
-    if (!IsOutsideCompilationOp(metrics.provenance(), metrics.name())) continue;
+    if (!IsOutsideCompilationOp(metrics.provenance(), metrics.long_name()))
+      continue;
     outside_compilation_device_op_time_ps += metrics.self_time_ps();
   }
   uint64 num_total_tf_ops = num_host_tf_ops + num_device_tf_ops;
diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
index a5e127a..4835487 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
@@ -29,6 +29,10 @@
 namespace profiler {
 namespace {
 
+// The maximum number of TensorFlow ops displayed on the TensorFlow Stats page:
+// 500 device-side ops and 500 host-side ops.
+const int kMaxNumOfOps = 500;
+
 TfStatsRecord ConvertOpMetricsToTfStatsRecord(
     bool on_device, const OpMetrics& metrics,
     double ridge_point_operational_intensity) {
@@ -60,7 +64,8 @@
     total_device_time_ps -= IdleTimePs(device_tf_metrics_db);
   }
   double total_device_time_us = PicosToMicros(total_device_time_ps);
-  for (const OpMetrics* metrics : SortedOpMetricsDb(device_tf_metrics_db)) {
+  for (const OpMetrics* metrics :
+       SortedOpMetricsDb(device_tf_metrics_db, kMaxNumOfOps)) {
     if (exclude_idle && IsIdleOp(*metrics)) continue;
     TfStatsRecord* record = tf_stats_table.add_tf_stats_record();
     *record = ConvertOpMetricsToTfStatsRecord(
@@ -84,8 +89,8 @@
     total_host_time_ps -= IdleTimePs(host_tf_metrics_db);
   }
   double total_host_time_us = PicosToMicros(total_host_time_ps);
-  for (const OpMetrics* metrics :
-       tensorflow::profiler::SortedOpMetricsDb(host_tf_metrics_db)) {
+  for (const OpMetrics* metrics : tensorflow::profiler::SortedOpMetricsDb(
+           host_tf_metrics_db, kMaxNumOfOps)) {
     if (exclude_idle && IsIdleOp(*metrics)) continue;
     TfStatsRecord* record = tf_stats_table.add_tf_stats_record();
     *record = ConvertOpMetricsToTfStatsRecord(
diff --git a/tensorflow/core/profiler/convert/trace_events_to_json.cc b/tensorflow/core/profiler/convert/trace_events_to_json.cc
index ba3e451..ad40292 100644
--- a/tensorflow/core/profiler/convert/trace_events_to_json.cc
+++ b/tensorflow/core/profiler/convert/trace_events_to_json.cc
@@ -21,7 +21,7 @@
 
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
-#include "include/json/json.h"
+#include "json/json.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
 
diff --git a/tensorflow/core/profiler/convert/trace_events_to_json_test.cc b/tensorflow/core/profiler/convert/trace_events_to_json_test.cc
index dc985f2..bf08a19 100644
--- a/tensorflow/core/profiler/convert/trace_events_to_json_test.cc
+++ b/tensorflow/core/profiler/convert/trace_events_to_json_test.cc
@@ -15,7 +15,7 @@
 
 #include "tensorflow/core/profiler/convert/trace_events_to_json.h"
 
-#include "include/json/json.h"
+#include "json/json.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
index 4233c5f..a6bc2a5 100644
--- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc
+++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
@@ -58,7 +58,7 @@
 
 PythonTracer::~PythonTracer() {
   Stop().IgnoreError();
-  PythonHooks::GetSingleton()->Finalize();
+  PythonHooks::GetSingleton()->Finalize(nullptr);
 }
 
 Status PythonTracer::Start() {
@@ -76,7 +76,7 @@
     return errors::Internal("TraceMeRecorder not started");
   }
   VLOG(1) << __FUNCTION__;
-  PythonHooks::GetSingleton()->Stop(options_);
+  PythonHooks::GetSingleton()->Stop();
   recording_ = false;
   return Status::OK();
 }
@@ -87,17 +87,12 @@
   // in the wrong threads.
   // We had assumed HostTracer::Stop is called when ProfilerSession tries to
   // serialize PythonTracer.
-  PythonHooks::GetSingleton()->Finalize();
+  PythonHooks::GetSingleton()->Finalize(nullptr);
   return Status::OK();
 }
 
 Status PythonTracer::CollectData(XSpace* space) {
-  // This ProfilerInterface rely on HostTracer to serialize its trace.
-  // Make sure unpaired traceme don't get recorded, because it will end up
-  // in the wrong threads.
-  // We had assumed HostTracer::Stop is called when ProfilerSession try to
-  // serialize PythonTracer.
-  PythonHooks::GetSingleton()->Finalize();
+  PythonHooks::GetSingleton()->Finalize(space);
   return Status::OK();
 }
 
@@ -107,8 +102,7 @@
 std::unique_ptr<ProfilerInterface> CreatePythonTracer(
     const ProfileOptions& options) {
   PythonHooksOptions pyhooks_options;
-  pyhooks_options.enable_trace_python_function =
-      options.python_tracer_level() && options.host_tracer_level();
+  pyhooks_options.enable_trace_python_function = options.python_tracer_level();
   pyhooks_options.enable_python_traceme = options.host_tracer_level() != 0;
   return absl::make_unique<PythonTracer>(pyhooks_options);
 }
diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.h b/tensorflow/core/profiler/internal/tfprof_timeline.h
index 834e3c9..fb9ff80 100644
--- a/tensorflow/core/profiler/internal/tfprof_timeline.h
+++ b/tensorflow/core/profiler/internal/tfprof_timeline.h
@@ -17,7 +17,7 @@
 #define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_
 
 #include "absl/strings/str_cat.h"
-#include "include/json/json.h"
+#include "json/json.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/step_stats.pb.h"
 #include "tensorflow/core/profiler/internal/tfprof_node_show.h"
diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc
index ee6eb55..8a20ef1 100644
--- a/tensorflow/core/profiler/lib/profiler_session.cc
+++ b/tensorflow/core/profiler/lib/profiler_session.cc
@@ -85,13 +85,28 @@
   // 1. Merge plane of host events with plane of CUPTI driver api.
   const profiler::XPlane* cupti_driver_api_plane =
       profiler::FindPlaneWithName(*space, profiler::kCuptiDriverApiPlaneName);
-  if (cupti_driver_api_plane) {
+  const profiler::XPlane* python_tracer_plane =
+      profiler::FindPlaneWithName(*space, profiler::kPythonTracerPlaneName);
+  if (cupti_driver_api_plane || python_tracer_plane) {
     profiler::XPlane* host_plane = profiler::FindOrAddMutablePlaneWithName(
         space, profiler::kHostThreadsPlaneName);
-    profiler::MergePlanes(*cupti_driver_api_plane, host_plane);
+    if (cupti_driver_api_plane) {
+      profiler::MergePlanes(*cupti_driver_api_plane, host_plane);
+    }
+    if (python_tracer_plane) {
+      profiler::MergePlanes(*python_tracer_plane, host_plane);
+    }
     profiler::SortXLinesBy(host_plane, profiler::XLinesComparatorByName());
-    profiler::RemovePlaneWithName(space, profiler::kCuptiDriverApiPlaneName);
+    // NOTE: RemovePlaneWithName might invalidate plane pointers, so do the
+    // removals as the last step.
+    if (cupti_driver_api_plane) {
+      profiler::RemovePlaneWithName(space, profiler::kCuptiDriverApiPlaneName);
+    }
+    if (python_tracer_plane) {
+      profiler::RemovePlaneWithName(space, profiler::kPythonTracerPlaneName);
+    }
   }
+
   // 2. Normalize all timestamps by shifting timeline to profiling start time.
   // NOTE: this has to be done before sorting XSpace due to timestamp overflow.
   profiler::NormalizeTimestamps(space, start_time_ns_);
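The merge-before-remove ordering above matters because RemovePlaneWithName mutates the repeated planes field, which can move storage and leave previously fetched XPlane pointers dangling. A minimal sketch of the hazard, using std::vector as a stand-in for the repeated field:

    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> planes = {"/host:CUPTI", "/host:python-tracer"};
      const std::string* python = &planes[1];  // pointer into the container
      planes.erase(planes.begin());  // erasure invalidates later elements
      // `python` is now dangling; reading through it is undefined behavior.
      // Hence the patch merges from both planes first and removes them last.
      return 0;
    }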
diff --git a/tensorflow/core/profiler/protobuf/op_metrics.proto b/tensorflow/core/profiler/protobuf/op_metrics.proto
index af38795..670ebd5 100644
--- a/tensorflow/core/profiler/protobuf/op_metrics.proto
+++ b/tensorflow/core/profiler/protobuf/op_metrics.proto
@@ -26,12 +26,14 @@
 }
 
 // Metrics for an operation (accumulated over all occurrences).
-// Next ID: 20
+// Next ID: 21
 message OpMetrics {
   // HLO module id. 0 for TF ops.
   uint64 hlo_module_id = 13;
   // Name of this op.
   string name = 6;
+  // Long name of this op (e.g., HLO expression).
+  string long_name = 20;
   // Category of this op.
   string category = 11;
   // Provenance of this op (e.g., if HLO op, original TF op).
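Consumers can prefer the new long_name when it is populated and fall back to name otherwise, mirroring the combiner change at the top of this patch. A hedged sketch, assuming the generated C++ type is tensorflow::profiler::OpMetrics:

    #include <string>

    const std::string& DisplayName(const tensorflow::profiler::OpMetrics& m) {
      // Prefer the long name (e.g., the HLO expression) when present.
      return m.long_name().empty() ? m.name() : m.long_name();
    }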
diff --git a/tensorflow/core/profiler/rpc/client/save_profile.cc b/tensorflow/core/profiler/rpc/client/save_profile.cc
index 81f9490..20ff496 100644
--- a/tensorflow/core/profiler/rpc/client/save_profile.cc
+++ b/tensorflow/core/profiler/rpc/client/save_profile.cc
@@ -51,13 +51,14 @@
 const absl::string_view kPathSep = "/";
 #endif
 
-string ProfilerJoinPathImpl(std::initializer_list<absl::string_view> paths) {
-  string result;
+std::string ProfilerJoinPathImpl(
+    std::initializer_list<absl::string_view> paths) {
+  std::string result;
   for (absl::string_view path : paths) {
     if (path.empty()) continue;
 
     if (result.empty()) {
-      result = string(path);
+      result = std::string(path);
       continue;
     }
 
@@ -75,7 +76,7 @@
 // A local duplication of ::tensorflow::io::JoinPath that supports Windows.
 // TODO(b/150699701): revert to use ::tensorflow::io::JoinPath when fixed.
 template <typename... T>
-string ProfilerJoinPath(const T&... args) {
+std::string ProfilerJoinPath(const T&... args) {
   return ProfilerJoinPathImpl({args...});
 }
 
@@ -86,8 +87,8 @@
                     const ProfileToolData& tool, std::ostream* os) {
   // Don't save the intermediate results for combining the per host tool data.
   if (absl::EndsWith(tool.name(), kTfStatsHelperSuffix)) return Status::OK();
-  string host_prefix = host.empty() ? "" : absl::StrCat(host, ".");
-  string path =
+  std::string host_prefix = host.empty() ? "" : absl::StrCat(host, ".");
+  std::string path =
       ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool.name()));
   TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data()));
   if (os) {
@@ -97,7 +98,8 @@
   return Status::OK();
 }
 
-Status WriteGzippedDataToFile(const string& filepath, const string& data) {
+Status WriteGzippedDataToFile(const std::string& filepath,
+                              const std::string& data) {
   std::unique_ptr<WritableFile> file;
   TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filepath, &file));
   io::ZlibCompressionOptions options = io::ZlibCompressionOptions::GZIP();
@@ -110,8 +112,9 @@
   return Status::OK();
 }
 
-Status GetOrCreateRunDir(const string& repository_root, const string& run,
-                         string* run_dir, std::ostream* os) {
+Status GetOrCreateRunDir(const std::string& repository_root,
+                         const std::string& run, std::string* run_dir,
+                         std::ostream* os) {
   // Dumps profile data to <repository_root>/<run>/.
   *run_dir = ProfilerJoinPath(repository_root, run);
   *os << "Creating directory: " << *run_dir;
@@ -120,21 +123,21 @@
 }
 }  // namespace
 
-string GetTensorBoardProfilePluginDir(const string& logdir) {
+std::string GetTensorBoardProfilePluginDir(const std::string& logdir) {
   constexpr char kPluginName[] = "plugins";
   constexpr char kProfileName[] = "profile";
   return ProfilerJoinPath(logdir, kPluginName, kProfileName);
 }
 
-Status MaybeCreateEmptyEventFile(const string& logdir) {
+Status MaybeCreateEmptyEventFile(const std::string& logdir) {
   // Suffix for an empty event file.  It should be kept in sync with
   // _EVENT_FILE_SUFFIX in tensorflow/python/eager/profiler.py.
   constexpr char kProfileEmptySuffix[] = ".profile-empty";
   TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(logdir));
 
-  std::vector<string> children;
+  std::vector<std::string> children;
   TF_RETURN_IF_ERROR(Env::Default()->GetChildren(logdir, &children));
-  for (const string& child : children) {
+  for (const std::string& child : children) {
     if (absl::EndsWith(child, kProfileEmptySuffix)) {
       return Status::OK();
     }
@@ -143,10 +146,10 @@
   return event_writer.InitWithSuffix(kProfileEmptySuffix);
 }
 
-Status SaveProfile(const string& repository_root, const string& run,
-                   const string& host, const ProfileResponse& response,
+Status SaveProfile(const std::string& repository_root, const std::string& run,
+                   const std::string& host, const ProfileResponse& response,
                    std::ostream* os) {
-  string run_dir;
+  std::string run_dir;
   TF_RETURN_IF_ERROR(GetOrCreateRunDir(repository_root, run, &run_dir, os));
   for (const auto& tool_data : response.tool_data()) {
     TF_RETURN_IF_ERROR(DumpToolData(run_dir, host, tool_data, os));
@@ -154,22 +157,24 @@
   return Status::OK();
 }
 
-Status SaveGzippedToolData(const string& repository_root, const string& run,
-                           const string& host, const string& tool_name,
-                           const string& data) {
-  string run_dir;
+Status SaveGzippedToolData(const std::string& repository_root,
+                           const std::string& run, const std::string& host,
+                           const std::string& tool_name,
+                           const std::string& data) {
+  std::string run_dir;
   std::stringstream ss;
   Status status = GetOrCreateRunDir(repository_root, run, &run_dir, &ss);
   LOG(INFO) << ss.str();
   TF_RETURN_IF_ERROR(status);
-  string host_prefix = host.empty() ? "" : absl::StrCat(host, ".");
-  string path = ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool_name));
+  std::string host_prefix = host.empty() ? "" : absl::StrCat(host, ".");
+  std::string path =
+      ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool_name));
   TF_RETURN_IF_ERROR(WriteGzippedDataToFile(path, data));
   LOG(INFO) << "Dumped gzipped tool data for " << tool_name << " to " << path;
   return Status::OK();
 }
 
-string GetCurrentTimeStampAsString() {
+std::string GetCurrentTimeStampAsString() {
   return absl::FormatTime("%E4Y_%m_%d_%H_%M_%S", absl::Now(),
                           absl::LocalTimeZone());
 }
diff --git a/tensorflow/core/profiler/rpc/client/save_profile.h b/tensorflow/core/profiler/rpc/client/save_profile.h
index c155502..9c15ef2 100644
--- a/tensorflow/core/profiler/rpc/client/save_profile.h
+++ b/tensorflow/core/profiler/rpc/client/save_profile.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_SAVE_PROFILE_H_
 
 #include <ostream>
+#include <string>
 
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
@@ -25,27 +26,28 @@
 namespace tensorflow {
 namespace profiler {
 
-string GetCurrentTimeStampAsString();
+std::string GetCurrentTimeStampAsString();
 
 // Returns the profile plugin directory given a logdir to TensorBoard.
-string GetTensorBoardProfilePluginDir(const string& logdir);
+std::string GetTensorBoardProfilePluginDir(const std::string& logdir);
 
 // Creates an empty event file if one does not already exist, which indicates
 // that we have a plugins/profile/ directory in the current logdir.
-Status MaybeCreateEmptyEventFile(const string& logdir);
+Status MaybeCreateEmptyEventFile(const std::string& logdir);
 
 // Saves all profiling tool data in a profile to <repository_root>/<run>/.
 // This writes user-facing log messages to `os`.
 // Note: this function creates a directory even when all fields in
 // ProfileResponse are unset/empty.
-Status SaveProfile(const string& repository_root, const string& run,
-                   const string& host, const ProfileResponse& response,
+Status SaveProfile(const std::string& repository_root, const std::string& run,
+                   const std::string& host, const ProfileResponse& response,
                    std::ostream* os);
 
 // Gzip the data and save to <repository_root>/<run>/.
-Status SaveGzippedToolData(const string& repository_root, const string& run,
-                           const string& host, const string& tool_name,
-                           const string& data);
+Status SaveGzippedToolData(const std::string& repository_root,
+                           const std::string& run, const std::string& host,
+                           const std::string& tool_name,
+                           const std::string& data);
 
 }  // namespace profiler
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/rpc/profiler_server.cc b/tensorflow/core/profiler/rpc/profiler_server.cc
index 966a94a..cfff3fc 100644
--- a/tensorflow/core/profiler/rpc/profiler_server.cc
+++ b/tensorflow/core/profiler/rpc/profiler_server.cc
@@ -27,6 +27,7 @@
 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
 
 namespace tensorflow {
+namespace profiler {
 
 void ProfilerServer::StartProfilerServer(int32 port) {
   std::string server_address = absl::StrCat("[::]:", port);
@@ -54,4 +55,5 @@
   }
 }
 
+}  // namespace profiler
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/rpc/profiler_server.h b/tensorflow/core/profiler/rpc/profiler_server.h
index b7148e7..45680e8 100644
--- a/tensorflow/core/profiler/rpc/profiler_server.h
+++ b/tensorflow/core/profiler/rpc/profiler_server.h
@@ -22,6 +22,7 @@
 #include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
 
 namespace tensorflow {
+namespace profiler {
 
 class ProfilerServer {
  public:
@@ -34,6 +35,7 @@
   std::unique_ptr<::grpc::Server> server_;
 };
 
+}  // namespace profiler
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVER_H_
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
index ba46381..8eadd87 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
@@ -34,6 +34,7 @@
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 
 namespace tensorflow {
+namespace profiler {
 namespace {
 
 const absl::string_view kXPlanePb = "xplane.pb";
@@ -115,4 +116,10 @@
   return absl::make_unique<ProfilerServiceImpl>();
 }
 
+}  // namespace profiler
+
+std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService() {
+  return absl::make_unique<profiler::ProfilerServiceImpl>();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.h b/tensorflow/core/profiler/rpc/profiler_service_impl.h
index 00a850a..3960b33 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.h
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.h
@@ -23,6 +23,11 @@
 
 std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService();
 
+namespace profiler {
+
+std::unique_ptr<grpc::ProfilerService::Service> CreateProfilerService();
+
+}  // namespace profiler
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_
diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD
index 2d3ec1d..d4957ca 100644
--- a/tensorflow/core/profiler/utils/BUILD
+++ b/tensorflow/core/profiler/utils/BUILD
@@ -212,6 +212,7 @@
     visibility = [":friends"],
     deps = [
         ":timespan",
+        ":trace_utils",
         ":xplane_builder",
         ":xplane_visitor",
         "//tensorflow/core:platform_base",
diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc
index 8656682..85367e3 100644
--- a/tensorflow/core/profiler/utils/group_events.cc
+++ b/tensorflow/core/profiler/utils/group_events.cc
@@ -405,12 +405,14 @@
          FindParent(HostEventType::kEagerKernelExecute) != nullptr;
 }
 
-EventNode* EventNode::FindParent(int64 event_type) const {
-  if (parent_) {
-    if (parent_->GetEventVisitor().Type() == event_type) {
-      return parent_;
-    }
-    return parent_->FindParent(event_type);
+const EventNode* EventNode::FindParent(int64 event_type) const {
+  absl::flat_hash_set<const EventNode*> seen;
+  const EventNode* node = this;
+  while (node) {
+    if (seen.contains(node)) break;
+    if (node->GetEventVisitor().Type() == event_type) return node;
+    seen.insert(node);
+    node = node->GetParent();
   }
   return nullptr;
 }
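The rewrite replaces the recursive parent search with an iterative walk guarded by a visited set, so an event graph that accidentally contains a parent cycle terminates instead of recursing forever; per the updated header comment, the search now also matches the starting node itself. The same pattern in a generic, illustrative form:

    #include <unordered_set>

    // Walk the parent chain from `start` and return the first node accepted
    // by `matches`; the visited set makes the walk safe against cycles.
    template <typename Node, typename Pred, typename Next>
    const Node* FindAncestor(const Node* start, Pred matches, Next parent_of) {
      std::unordered_set<const Node*> seen;
      for (const Node* node = start; node != nullptr; node = parent_of(node)) {
        if (!seen.insert(node).second) break;  // already visited: cycle
        if (matches(node)) return node;
      }
      return nullptr;
    }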
diff --git a/tensorflow/core/profiler/utils/group_events.h b/tensorflow/core/profiler/utils/group_events.h
index e03acf3..44026c8 100644
--- a/tensorflow/core/profiler/utils/group_events.h
+++ b/tensorflow/core/profiler/utils/group_events.h
@@ -89,8 +89,8 @@
 
   bool IsNestedIn(EventNode* parent);
 
-  // Returns the closest parent of the given event type.
-  EventNode* FindParent(int64 event_type) const;
+  // Returns the closest parent (including itself) of the given event type.
+  const EventNode* FindParent(int64 event_type) const;
 
   absl::optional<ContextInfo> GetProducerContext() const {
     return producer_context_;
diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc
index 1f01e00..2e10ae5 100644
--- a/tensorflow/core/profiler/utils/op_utils.cc
+++ b/tensorflow/core/profiler/utils/op_utils.cc
@@ -82,13 +82,9 @@
   op_metrics->set_occurrences(op_metrics->occurrences() + occurrences);
   op_metrics->set_time_ps(op_metrics->time_ps() + time_ps);
   op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_time_ps);
-  op_metrics->set_flops(op_metrics->flops() +
-                        GetCappedPerf(flops * occurrences, self_time_ps,
-                                      peak_tera_flops_per_second_));
-  op_metrics->set_bytes_accessed(
-      op_metrics->bytes_accessed() +
-      GetCappedPerf(bytes_accessed * occurrences, self_time_ps,
-                    peak_hbm_bw_giga_bytes_per_second_ / 1000));
+  op_metrics->set_flops(op_metrics->flops() + flops * occurrences);
+  op_metrics->set_bytes_accessed(op_metrics->bytes_accessed() +
+                                 bytes_accessed * occurrences);
   CombineMemoryAccessedBreakdown(
       memory_accessed_breakdown,
       op_metrics->mutable_memory_accessed_breakdown());
diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc
index 19831a5..4273adf 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.cc
+++ b/tensorflow/core/profiler/utils/xplane_schema.cc
@@ -30,6 +30,7 @@
 const absl::string_view kCuptiDriverApiPlaneName = "/host:CUPTI";
 const absl::string_view kMetadataPlaneName = "/host:metadata";
 const absl::string_view kTFStreamzPlaneName = "/host:tfstreamz";
+const absl::string_view kPythonTracerPlaneName = "/host:python-tracer";
 
 const absl::string_view kStepLineName = "Steps";
 const absl::string_view kTensorFlowNameScopeLineName = "TensorFlow Name Scope";
diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h
index ea36561..e4677c0 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.h
+++ b/tensorflow/core/profiler/utils/xplane_schema.h
@@ -36,6 +36,8 @@
 ABSL_CONST_INIT extern const absl::string_view kMetadataPlaneName;
 // Name of XPlane that contains KPI-related metrics.
 ABSL_CONST_INIT extern const absl::string_view kTFStreamzPlaneName;
+// Name of XPlane that contains events from the Python tracer.
+ABSL_CONST_INIT extern const absl::string_view kPythonTracerPlaneName;
 
 // Names of XLines that contain ML-level events.
 ABSL_CONST_INIT extern const absl::string_view kStepLineName;
diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc
index 867d131..825469f 100644
--- a/tensorflow/core/profiler/utils/xplane_utils.cc
+++ b/tensorflow/core/profiler/utils/xplane_utils.cc
@@ -40,13 +40,6 @@
   return Timespan(event.offset_ps(), event.duration_ps());
 }
 
-// Functor that compares XEvents of the same XLine for sorting by timespan.
-struct XEventsComparator {
-  bool operator()(const XEvent* a, const XEvent* b) const {
-    return XEventTimespan(*a) < XEventTimespan(*b);
-  }
-};
-
 }  // namespace
 
 const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) {
@@ -144,6 +137,10 @@
                lines->end());
 }
 
+bool XEventsComparator::operator()(const XEvent* a, const XEvent* b) const {
+  return XEventTimespan(*a) < XEventTimespan(*b);
+}
+
 void SortXPlane(XPlane* plane) {
   for (XLine& line : *plane->mutable_lines()) {
     auto& events = *line.mutable_events();
diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h
index ff65f5a..5cd5275 100644
--- a/tensorflow/core/profiler/utils/xplane_utils.h
+++ b/tensorflow/core/profiler/utils/xplane_utils.h
@@ -21,6 +21,7 @@
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/trace_utils.h"
 
 namespace tensorflow {
 namespace profiler {
@@ -75,6 +76,26 @@
 // Sorts each plane of the XSpace.
 void SortXSpace(XSpace* space);
 
+// Functor that compares XEvents for sorting by timespan.
+struct XEventsComparator {
+  bool operator()(const XEvent* a, const XEvent* b) const;
+};
+
+// Returns a sorted vector of all XEvents in the given XPlane.
+template <class Compare>
+std::vector<XEvent*> GetSortedEvents(XPlane* plane, Compare comp,
+                                     bool include_derived_events = false) {
+  std::vector<XEvent*> events;
+  for (XLine& line : *plane->mutable_lines()) {
+    if (!include_derived_events && IsDerivedThreadId(line.id())) continue;
+    for (XEvent& event : *line.mutable_events()) {
+      events.push_back(&event);
+    }
+  }
+  absl::c_sort(events, comp);
+  return events;
+}
+
 // Normalize timestamps by time-shifting to start_time_ns_ as origin.
 void NormalizeTimestamps(XPlane* plane, uint64 start_time_ns);
 void NormalizeTimestamps(XSpace* space, uint64 start_time_ns);
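A short usage sketch for the GetSortedEvents helper added above; `plane` is assumed to be a mutable XPlane* obtained elsewhere, e.g. via FindOrAddMutablePlaneWithName:

    // Collect every non-derived event in the plane, ordered by timespan.
    std::vector<XEvent*> events = GetSortedEvents(plane, XEventsComparator());
    for (XEvent* event : events) {
      // Process events in chronological order, e.g. for event grouping.
    }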
diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto
index c756644..83ba782 100644
--- a/tensorflow/core/protobuf/saved_object_graph.proto
+++ b/tensorflow/core/protobuf/saved_object_graph.proto
@@ -140,6 +140,7 @@
   VariableSynchronization synchronization = 4;
   VariableAggregation aggregation = 5;
   string name = 6;
+  string device = 7;
 }
 
 // Represents `FunctionSpec` used in `Function`. This represents a
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 00aad95..48adfec 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -108,7 +108,7 @@
 
 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 498  // Updated: 2020/8/19
+#define TF_GRAPH_DEF_VERSION 510  // Updated: 2020/8/31
 
 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index d8abbd0..d20b5ab 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -157,6 +157,7 @@
         ":tpu_api",
         ":tpu_compilation_device",
         ":tpu_config_c_api",
+        ":tpu_executor_init_fns",
         ":tpu_library_init_fns",
         ":tpu_node_device",
         ":tpu_system_device",
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index 882947c..cdf32c5 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -961,7 +961,7 @@
 const absl::flat_hash_set<std::string>& PlaceOnTPUOpList() {
   static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
       {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
-       "NextIteration", "Shape"});
+       "NextIteration", "Shape", "_Retval"});
   return *place_on_tpu_ops;
 }
 
@@ -1568,7 +1568,8 @@
     arg_shape.shape = TensorShape();  // Variables are always scalars.
     arg_shape.handle_shape = info->handle_shape;
     arg_shape.handle_type = info->handle_type;
-    TF_RET_CHECK(arg_shape.handle_type != DT_INVALID);
+    TF_RET_CHECK(arg_shape.handle_type != DT_INVALID)
+        << " input edge: " << input_edges[edge_pos]->DebugString();
     ++edge_pos;
   }
 
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 9f72c7a..8f97b2e 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -4,6 +4,7 @@
     "//tensorflow/core/platform:build_config.bzl",
     "tf_proto_library_cc",
 )
+load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")  # buildifier: disable=same-origin-load
 load(
     "//tensorflow:tensorflow.bzl",
     "tf_kernel_library",
@@ -45,8 +46,10 @@
     name = "tpu_compile_op_common",
     srcs = ["tpu_compile_op_common.cc"],
     hdrs = ["tpu_compile_op_common.h"],
-    linkopts = ["-Wl,--warn-backrefs-exclude=*/learning/brain/google/xla/_objs/tpu_compilation_metrics_google/*"],  # TODO(b/163560146) Fix the dependency issue
-    deps = [
+    deps = select({
+        WITH_TPU_SUPPORT: [":tpu_compilation_metrics"],
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
+    }) + [
         ":tpu_compilation_cache_entry_unloader",
         ":tpu_compilation_cache_interface",
         ":tpu_compilation_metrics_hdrs",
@@ -59,6 +62,9 @@
         ":tpu_util",
         ":tpu_util_c_api_hdrs",
         ":tpu_util_hdrs",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@com_google_absl//absl/types:variant",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
@@ -79,9 +85,6 @@
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
         "//tensorflow/stream_executor/tpu:tpu_platform_interface",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "@com_google_absl//absl/types:variant",
     ],
     alwayslink = 1,
 )
@@ -126,7 +129,7 @@
         ":tpu_program_c_api_hdrs",
         ":tpu_util_c_api_hdrs",
         "//tensorflow/core/tpu:libtftpu_header",
-        "//tensorflow/stream_executor/tpu:proto_helper",
+        "//tensorflow/stream_executor/tpu:c_api_decl",
     ],
     alwayslink = True,
 )
@@ -234,8 +237,8 @@
         "tpu_compilation_cache_lookup.h",
     ],
     deps = [
+        ":tpu_compilation_cache_common_proto_cc",
         ":tpu_compilation_cache_interface",
-        ":tpu_compilation_cache_proto_cc",
         "//tensorflow/core/lib/core:refcount",
         "//tensorflow/core/platform:status",
         "//tensorflow/core/profiler/lib:traceme",
@@ -247,11 +250,11 @@
     srcs = ["tpu_compilation_cache_local_lookup.cc"],
     hdrs = ["tpu_compilation_cache_local_lookup.h"],
     deps = [
+        ":tpu_compilation_cache_common_proto_cc",
         ":tpu_compilation_cache_entry",
         ":tpu_compilation_cache_external",
         ":tpu_compilation_cache_interface",
         ":tpu_compilation_cache_lookup",
-        ":tpu_compilation_cache_proto_cc",
         "//tensorflow/core/platform:status",
     ],
 )
@@ -334,16 +337,22 @@
     name = "tpu_compilation_cache_interface",
     srcs = ["tpu_compilation_cache_interface.cc"],
     hdrs = ["tpu_compilation_cache_interface.h"],
-    linkopts = ["-Wl,--warn-backrefs-exclude=*/learning/brain/google/xla/_objs/tpu_compilation_metrics_google/*"],  # TODO(b/163560146) Fix the dependency issue.
-    deps = [
+    deps = select({
+        WITH_TPU_SUPPORT: [":tpu_compilation_metrics"],
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
+    }) + [
         ":compiled_subgraph",
+        ":tpu_compilation_cache_common_proto_cc",
         ":tpu_compilation_cache_entry",
         ":tpu_compilation_cache_key",
-        ":tpu_compilation_cache_proto_cc",
         ":tpu_compilation_metrics_hdrs",
         ":tpu_util",
         ":tpu_util_hdrs",
         ":trace_util_hdrs",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
         "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:framework",
@@ -352,10 +361,6 @@
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/tpu:tpu_api",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:node_hash_map",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/synchronization",
     ],
     alwayslink = 1,
 )
@@ -368,10 +373,10 @@
     ],
     deps = [
         ":compiled_subgraph",
+        ":tpu_compilation_cache_common_proto_cc",
         ":tpu_compilation_cache_entry",
         ":tpu_compilation_cache_interface",
         ":tpu_compilation_cache_key",
-        ":tpu_compilation_cache_proto_cc",
         ":tpu_compilation_metrics",  # buildcleaner: keep
         ":tpu_compilation_metrics_hdrs",
         ":tpu_compile_c_api_hdrs",
@@ -489,23 +494,148 @@
     hdrs = ["tpu_util.h"],
     deps = [
         ":tpu_compilation_cache_key",
-        ":tpu_program_group_interface",
+        ":tpu_util_c_api_hdrs",
         "//tensorflow/cc:ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:compile_only_client",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/tpu:tpu_api",
         "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/time",
     ],
     alwayslink = 1,
 )
 
+cc_library(
+    name = "tpu_compilation_cache_rpc_support_hdrs",
+    hdrs = ["tpu_compilation_cache_rpc_support.h"],
+    copts = select({
+        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
+        DEFAULT: [],
+    }),
+    deps = select({
+        WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
+    }) + [
+        ":tpu_compilation_cache_entry",
+        ":tpu_compilation_cache_interface",
+        ":tpu_compilation_cache_lookup",
+        ":tpu_program_group_interface",
+        "@com_google_absl//absl/strings",
+        "//tensorflow/core/platform:status",
+        tf_grpc_cc_dependency(),
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_rpc_support",
+    srcs = ["tpu_compilation_cache_rpc_support.cc"],
+    copts = select({
+        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
+        DEFAULT: [],
+    }),
+    deps = [
+        ":tpu_compilation_cache_proto_cc",
+        ":tpu_compilation_cache_rpc_support_hdrs",
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_rpc_lookup",
+    srcs = ["tpu_compilation_cache_rpc_lookup.cc"],
+    hdrs = ["tpu_compilation_cache_rpc_lookup.h"],
+    copts = select({
+        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
+        DEFAULT: [],
+    }),
+    deps = select({
+        WITH_TPU_SUPPORT: [":tpu_compilation_cache_rpc_support"],
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support"],
+    }) + [
+        ":tpu_compilation_cache_grpc",
+        ":tpu_compilation_cache_interface",
+        ":tpu_compilation_cache_lookup",
+        ":tpu_compilation_cache_common_proto_cc",
+        ":tpu_compilation_cache_rpc_support_hdrs",
+        ":tpu_program_group_interface",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+        tf_grpc_cc_dependency(),
+    ],
+)
+
 tf_proto_library_cc(
     name = "tpu_compilation_cache_proto",
     srcs = ["tpu_compilation_cache.proto"],
+    has_services = True,
     cc_api_version = 2,
+    create_java_proto = False,
+    protodeps = [
+        ":tpu_compilation_cache_common_proto",
+        "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
+    ],
+)
+
+tf_proto_library_cc(
+    name = "tpu_compilation_cache_common_proto",
+    srcs = ["tpu_compilation_cache_common.proto"],
+    cc_api_version = 2,
+    create_java_proto = False,
+)
+
+cc_library(
+    name = "tpu_compilation_cache_grpc",
+    srcs = ["tpu_compilation_cache_grpc.cc"],
+    hdrs = ["tpu_compilation_cache_grpc.h"],
+    copts = select({
+        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
+        DEFAULT: [],
+    }),
+    deps = select({
+        WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
+    }) + [
+        ":tpu_compilation_cache_common_proto_cc",
+        tf_grpc_cc_dependency(),
+    ],
+)
+
+cc_library(
+    name = "tpu_compilation_cache_service",
+    srcs = ["tpu_compilation_cache_service.cc"],
+    hdrs = ["tpu_compilation_cache_service.h"],
+    copts = select({
+        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
+        DEFAULT: [],
+    }),
+    deps = select({
+        WITH_TPU_SUPPORT: [
+            ":tpu_compilation_cache_rpc_support",
+            ":tpu_compilation_cache_proto_cc",
+        ],
+        DEFAULT: [
+            "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support",
+            "//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc",
+        ],
+    }) + [
+        ":tpu_compilation_cache_common_proto_cc",
+        ":tpu_compilation_cache_entry",
+        ":tpu_compilation_cache_grpc",
+        ":tpu_compilation_cache_interface",
+        ":tpu_compilation_cache_rpc_support_hdrs",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/synchronization",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_call",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+        "//tensorflow/core/lib/core:refcount",
+        "//tensorflow/core/lib/core:threadpool",
+        "//tensorflow/core/platform:coding",
+        "//tensorflow/core:protos_all_cc",
+        tf_grpc_cc_dependency(),
+    ],
 )
 
 cc_library(
@@ -795,3 +925,19 @@
     ],
     alwayslink = True,
 )
+
+cc_library(
+    name = "tpu_pod_state",
+    srcs = ["tpu_pod_state.cc"],
+    hdrs = ["tpu_pod_state.h"],
+    copts = select({
+        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
+        DEFAULT: [],
+    }),
+    deps = [
+        ":tpu_compilation_cache_service",
+        ":tpu_util",
+        "//tensorflow/core:framework",
+        tf_grpc_cc_dependency(),
+    ],
+)
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto
index 8308cba..f452922 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto
@@ -16,10 +16,27 @@
 
 package tensorflow.tpu;
 
-// Target type for compilation cache fetch operation.
-enum CompilationCacheFetchTarget {
-  INVALID = 0;
-  MAIN = 1;
-  SHARDING = 2;
-  UNSHARDING = 3;
+import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
+import "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto";
+
+// Response for GetTpuProgram RPC.
+message GetTpuProgramResponseExternal {
+  message Blob {
+    bytes data = 1;
+  }
+
+  Blob proto = 1;
+  tf2xla.HostComputeMetadata host_compute_metadata = 2;
+  bool may_modify_variables = 3;
+  Blob compiler_metadata = 4;
+  // Whether the program is empty, which could be true for sharding/unsharding
+  // entries.
+  bool is_empty = 5;
+}
+
+service TpuCompilationCacheServiceExternal {
+  // This method requests the cached proto that the TPU execute op has been
+  // instructed to execute.
+  rpc GetTpuProgram(GetTpuProgramRequest)
+      returns (GetTpuProgramResponseExternal) {}
 }
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto b/tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto
new file mode 100644
index 0000000..89b92ae
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// Target type for compilation cache fetch operation.
+enum CompilationCacheFetchTarget {
+  INVALID = 0;
+  MAIN = 1;
+  SHARDING = 2;
+  UNSHARDING = 3;
+}
+
+message TpuCompilationUidAndIndex {
+  int64 uid = 1;
+  int32 proto_index = 2;
+}
+
+message GetTpuProgramRequest {
+  oneof key_oneof {
+    string key = 1;
+    TpuCompilationUidAndIndex uid_and_index = 2;
+  }
+  CompilationCacheFetchTarget fetch_target = 3;
+}
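A hedged sketch of populating the request message defined above from C++; the key and fetch target values are placeholders:

    tensorflow::tpu::GetTpuProgramRequest request;
    request.set_key("example-proto-key");  // or fill in uid_and_index instead
    request.set_fetch_target(tensorflow::tpu::MAIN);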
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
index 51b5ffb..c3f95e7 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
@@ -30,7 +30,7 @@
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc
new file mode 100644
index 0000000..207a60e
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc
@@ -0,0 +1,107 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h"
+
+#include <grpcpp/impl/codegen/async_stream.h>
+#include <grpcpp/impl/codegen/async_unary_call.h>
+#include <grpcpp/impl/codegen/channel_interface.h>
+#include <grpcpp/impl/codegen/client_callback.h>
+#include <grpcpp/impl/codegen/client_unary_call.h>
+#include <grpcpp/impl/codegen/method_handler.h>
+#include <grpcpp/impl/codegen/rpc_service_method.h>
+#include <grpcpp/impl/codegen/server_callback.h>
+#include <grpcpp/impl/codegen/service_type.h>
+#include <grpcpp/impl/codegen/sync_stream.h>
+
+#include <functional>
+namespace tensorflow {
+namespace tpu {
+
+static const char* grpcTpuCompilationCacheService_method_names[] = {
+#if defined(LIBTFTPU)
+    "/tensorflow.tpu.TpuCompilationCacheServiceExternal/GetTpuProgram",
+#else  // LIBTFTPU
+    "/tensorflow.tpu.TpuCompilationCacheService/GetTpuProgram",
+#endif  // LIBTFTPU
+};
+
+std::unique_ptr<grpc::TpuCompilationCacheService::Stub>
+grpc::TpuCompilationCacheService::NewStub(
+    const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+    const ::grpc::StubOptions& options) {
+  (void)options;
+  std::unique_ptr<grpc::TpuCompilationCacheService::Stub> stub(
+      new grpc::TpuCompilationCacheService::Stub(channel));
+  return stub;
+}
+
+grpc::TpuCompilationCacheService::Stub::Stub(
+    const std::shared_ptr< ::grpc::ChannelInterface>& channel)
+    : channel_(channel),
+      rpcmethod_get_tpu_program_(grpcTpuCompilationCacheService_method_names[0],
+                                 ::grpc::internal::RpcMethod::NORMAL_RPC,
+                                 channel) {}
+
+::grpc::Status grpc::TpuCompilationCacheService::Stub::GetTpuProgram(
+    ::grpc::ClientContext* context, const RequestType& request,
+    ResponseType* response) {
+  return ::grpc::internal::BlockingUnaryCall(
+      channel_.get(), rpcmethod_get_tpu_program_, context, request, response);
+}
+
+::grpc::ClientAsyncResponseReader<
+    grpc::TpuCompilationCacheService::ResponseType>*
+grpc::TpuCompilationCacheService::Stub::AsyncGetTpuProgramRaw(
+    ::grpc::ClientContext* context, const RequestType& request,
+    ::grpc::CompletionQueue* cq) {
+  return ::grpc::internal::ClientAsyncResponseReaderFactory<
+      ResponseType>::Create(channel_.get(), cq, rpcmethod_get_tpu_program_,
+                            context, request, true);
+}
+
+::grpc::ClientAsyncResponseReader<
+    grpc::TpuCompilationCacheService::ResponseType>*
+grpc::TpuCompilationCacheService::Stub::PrepareAsyncGetTpuProgramRaw(
+    ::grpc::ClientContext* context, const RequestType& request,
+    ::grpc::CompletionQueue* cq) {
+  return ::grpc::internal::ClientAsyncResponseReaderFactory<
+      ResponseType>::Create(channel_.get(), cq, rpcmethod_get_tpu_program_,
+                            context, request, false);
+}
+
+grpc::TpuCompilationCacheService::Service::Service() {
+  AddMethod(new ::grpc::internal::RpcServiceMethod(
+      grpcTpuCompilationCacheService_method_names[0],
+      ::grpc::internal::RpcMethod::NORMAL_RPC,
+      new ::grpc::internal::RpcMethodHandler<
+          grpc::TpuCompilationCacheService::Service, RequestType, ResponseType>(
+          std::mem_fn(
+              &grpc::TpuCompilationCacheService::Service::GetTpuProgram),
+          this)));
+}
+
+grpc::TpuCompilationCacheService::Service::~Service() {}
+
+::grpc::Status grpc::TpuCompilationCacheService::Service::GetTpuProgram(
+    ::grpc::ServerContext* context, const RequestType* request,
+    ResponseType* response) {
+  (void)context;
+  (void)request;
+  (void)response;
+  return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h
new file mode 100644
index 0000000..324fc9e
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h
@@ -0,0 +1,235 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Copied from auto-generated gRPC code in order to enable using grpc_call.h
+// for raw message handling.
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_
+
+#include <grpcpp/impl/codegen/async_generic_service.h>
+#include <grpcpp/impl/codegen/async_stream.h>
+#include <grpcpp/impl/codegen/async_unary_call.h>
+#include <grpcpp/impl/codegen/client_callback.h>
+#include <grpcpp/impl/codegen/client_context.h>
+#include <grpcpp/impl/codegen/completion_queue.h>
+#include <grpcpp/impl/codegen/method_handler.h>
+#include <grpcpp/impl/codegen/proto_utils.h>
+#include <grpcpp/impl/codegen/rpc_method.h>
+#include <grpcpp/impl/codegen/server_callback.h>
+#include <grpcpp/impl/codegen/server_context.h>
+#include <grpcpp/impl/codegen/service_type.h>
+#include <grpcpp/impl/codegen/status.h>
+#include <grpcpp/impl/codegen/stub_options.h>
+#include <grpcpp/impl/codegen/sync_stream.h>
+
+#include <functional>
+
+#if defined(LIBTFTPU)
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#else
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"  // copybara"
+#endif
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
+
+namespace tensorflow {
+namespace tpu {
+namespace grpc {
+class TpuCompilationCacheService final {
+ public:
+  using RequestType = ::tensorflow::tpu::GetTpuProgramRequest;
+#if defined(LIBTFTPU)
+  using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal;
+#else
+  using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse;
+#endif
+
+  // N.B. This must be synchronized with the method order in
+  // tpu_compilation_cache.proto.
+  enum class MethodId { kGetTpuProgram = 0 };
+
+  static constexpr char const* service_full_name() {
+#if defined(LIBTFTPU)
+    return "tensorflow.tpu.TpuCompilationCacheServiceExternal";
+#else
+    return "tensorflow.tpu.TpuCompilationCacheService";
+#endif
+  }
+  class StubInterface {
+   public:
+    virtual ~StubInterface() {}
+    // This method requests the cached proto that the TPU execute op has
+    // been instructed to execute.
+    virtual ::grpc::Status GetTpuProgram(::grpc::ClientContext* context,
+                                         const RequestType& request,
+                                         ResponseType* response) = 0;
+    std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface<ResponseType>>
+    AsyncGetTpuProgram(::grpc::ClientContext* context,
+                       const RequestType& request,
+                       ::grpc::CompletionQueue* cq) {
+      return std::unique_ptr<
+          ::grpc::ClientAsyncResponseReaderInterface<ResponseType>>(
+          AsyncGetTpuProgramRaw(context, request, cq));
+    }
+    std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface<ResponseType>>
+    PrepareAsyncGetTpuProgram(::grpc::ClientContext* context,
+                              const RequestType& request,
+                              ::grpc::CompletionQueue* cq) {
+      return std::unique_ptr<
+          ::grpc::ClientAsyncResponseReaderInterface<ResponseType>>(
+          PrepareAsyncGetTpuProgramRaw(context, request, cq));
+    }
+
+   private:
+    virtual ::grpc::ClientAsyncResponseReaderInterface<ResponseType>*
+    AsyncGetTpuProgramRaw(::grpc::ClientContext* context,
+                          const RequestType& request,
+                          ::grpc::CompletionQueue* cq) = 0;
+    virtual ::grpc::ClientAsyncResponseReaderInterface<ResponseType>*
+    PrepareAsyncGetTpuProgramRaw(::grpc::ClientContext* context,
+                                 const RequestType& request,
+                                 ::grpc::CompletionQueue* cq) = 0;
+  };
+  class Stub final : public StubInterface {
+   public:
+    explicit Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel);
+    ::grpc::Status GetTpuProgram(::grpc::ClientContext* context,
+                                 const RequestType& request,
+                                 ResponseType* response) override;
+    std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>
+    AsyncGetTpuProgram(::grpc::ClientContext* context,
+                       const RequestType& request,
+                       ::grpc::CompletionQueue* cq) {
+      return std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>(
+          AsyncGetTpuProgramRaw(context, request, cq));
+    }
+    std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>
+    PrepareAsyncGetTpuProgram(::grpc::ClientContext* context,
+                              const RequestType& request,
+                              ::grpc::CompletionQueue* cq) {
+      return std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>(
+          PrepareAsyncGetTpuProgramRaw(context, request, cq));
+    }
+
+   private:
+    std::shared_ptr<::grpc::ChannelInterface> channel_;
+    ::grpc::ClientAsyncResponseReader<ResponseType>* AsyncGetTpuProgramRaw(
+        ::grpc::ClientContext* context, const RequestType& request,
+        ::grpc::CompletionQueue* cq) override;
+    ::grpc::ClientAsyncResponseReader<ResponseType>*
+    PrepareAsyncGetTpuProgramRaw(::grpc::ClientContext* context,
+                                 const RequestType& request,
+                                 ::grpc::CompletionQueue* cq) override;
+    const ::grpc::internal::RpcMethod rpcmethod_get_tpu_program_;
+  };
+  static std::unique_ptr<Stub> NewStub(
+      const std::shared_ptr<::grpc::ChannelInterface>& channel,
+      const ::grpc::StubOptions& options = ::grpc::StubOptions());
+
+  class Service : public ::grpc::Service {
+   public:
+    Service();
+    ~Service() override;
+    // This method requests the cached proto that the TPU execute op has
+    // been instructed to execute.
+    virtual ::grpc::Status GetTpuProgram(::grpc::ServerContext* context,
+                                         const RequestType* request,
+                                         ResponseType* response);
+  };
+  template <class BaseClass>
+  class WithAsyncMethod_GetTpuProgram : public BaseClass {
+   private:
+    void BaseClassMustBeDerivedFromService(const Service* service) {}
+
+   public:
+    WithAsyncMethod_GetTpuProgram() { ::grpc::Service::MarkMethodAsync(0); }
+    ~WithAsyncMethod_GetTpuProgram() override {
+      BaseClassMustBeDerivedFromService(this);
+    }
+    // disable synchronous version of this method
+    ::grpc::Status GetTpuProgram(::grpc::ServerContext* context,
+                                 const RequestType* request,
+                                 ResponseType* response) override {
+      abort();
+      return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
+    }
+    void RequestGetTpuProgram(
+        ::grpc::ServerContext* context, RequestType* request,
+        ::grpc::ServerAsyncResponseWriter<ResponseType>* response,
+        ::grpc::CompletionQueue* new_call_cq,
+        ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+      ::grpc::Service::RequestAsyncUnary(0, context, request, response,
+                                         new_call_cq, notification_cq, tag);
+    }
+
+    // Make RequestAsyncUnary accessible to grpc_call.h
+    using ::grpc::Service::RequestAsyncUnary;
+  };
+  typedef WithAsyncMethod_GetTpuProgram<Service> AsyncService;
+  template <class BaseClass>
+  class WithGenericMethod_GetTpuProgram : public BaseClass {
+   private:
+    void BaseClassMustBeDerivedFromService(const Service* service) {}
+
+   public:
+    WithGenericMethod_GetTpuProgram() { ::grpc::Service::MarkMethodGeneric(0); }
+    ~WithGenericMethod_GetTpuProgram() override {
+      BaseClassMustBeDerivedFromService(this);
+    }
+    // disable synchronous version of this method
+    ::grpc::Status GetTpuProgram(::grpc::ServerContext* context,
+                                 const RequestType* request,
+                                 ResponseType* response) override {
+      abort();
+      return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
+    }
+  };
+  template <class BaseClass>
+  class WithStreamedUnaryMethod_GetTpuProgram : public BaseClass {
+   private:
+    void BaseClassMustBeDerivedFromService(const Service* service) {}
+
+   public:
+    WithStreamedUnaryMethod_GetTpuProgram() {
+      ::grpc::Service::MarkMethodStreamed(
+          0,
+          new ::grpc::internal::StreamedUnaryHandler<RequestType, ResponseType>(
+              std::bind(&WithStreamedUnaryMethod_GetTpuProgram<
+                            BaseClass>::StreamedGetTpuProgram,
+                        this, std::placeholders::_1, std::placeholders::_2)));
+    }
+    ~WithStreamedUnaryMethod_GetTpuProgram() override {
+      BaseClassMustBeDerivedFromService(this);
+    }
+    // disable regular version of this method
+    ::grpc::Status GetTpuProgram(::grpc::ServerContext* context,
+                                 const RequestType* request,
+                                 ResponseType* response) override {
+      abort();
+      return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
+    }
+    // replace default version of method with streamed unary
+    virtual ::grpc::Status StreamedGetTpuProgram(
+        ::grpc::ServerContext* context,
+        ::grpc::ServerUnaryStreamer<RequestType, ResponseType>*
+            server_unary_streamer) = 0;
+  };
+  typedef WithStreamedUnaryMethod_GetTpuProgram<Service> StreamedUnaryService;
+  typedef Service SplitStreamedService;
+  typedef WithStreamedUnaryMethod_GetTpuProgram<Service> StreamedService;
+};
+}  // namespace grpc
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_
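A usage sketch for the hand-written stub declared above: build a channel, create a stub with NewStub, and issue a blocking GetTpuProgram call. The address handling and insecure credentials are illustrative only:

    #include <memory>
    #include <string>

    #include <grpcpp/create_channel.h>
    #include <grpcpp/security/credentials.h>

    #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h"

    using TpuCacheService = tensorflow::tpu::grpc::TpuCompilationCacheService;

    void LookupExample(const std::string& server_address) {
      auto channel = ::grpc::CreateChannel(server_address,
                                           ::grpc::InsecureChannelCredentials());
      std::unique_ptr<TpuCacheService::Stub> stub =
          TpuCacheService::NewStub(channel);

      ::grpc::ClientContext context;
      TpuCacheService::RequestType request;
      request.set_key("example-proto-key");
      TpuCacheService::ResponseType response;
      ::grpc::Status status = stub->GetTpuProgram(&context, request, &response);
      // Inspect `status` and, on success, `response`.
    }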
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
index 7b206fb..e1e7cf2 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h
@@ -31,7 +31,7 @@
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h
index 8db4c11..96f9235 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h
@@ -16,9 +16,8 @@
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOCAL_LOOKUP_H_
 
 #include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
 
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
index ab47632..fc81970 100644
--- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
@@ -17,7 +17,7 @@
 
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc
new file mode 100644
index 0000000..e3560de
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc
@@ -0,0 +1,202 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
+
+#include <grpcpp/security/credentials.h>
+
+#include "absl/strings/str_cat.h"
+#include "absl/time/time.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
+
+namespace tensorflow {
+namespace tpu {
+namespace {
+
+#if defined(LIBTFTPU)
+using ResponseType = GetTpuProgramResponseExternal;
+#else
+using ResponseType = GetTpuProgramResponse;
+#endif
+
+static constexpr absl::Duration kProtoTimeout = absl::Minutes(15);
+static gpr_timespec TimeToGprTimespec(absl::Time time) {
+  if (time == absl::InfiniteFuture()) {
+    return gpr_inf_future(GPR_CLOCK_REALTIME);
+  }
+  if (time == absl::InfinitePast()) {
+    return gpr_inf_past(GPR_CLOCK_REALTIME);
+  }
+
+  gpr_timespec spec;
+  timespec t = absl::ToTimespec(time);
+  spec.tv_sec = t.tv_sec;
+  spec.tv_nsec = static_cast<int32_t>(t.tv_nsec);
+  spec.clock_type = GPR_CLOCK_REALTIME;
+  return spec;
+}
+}  // namespace
+TpuCompilationCacheRpcLookup::TpuCompilationCacheRpcLookup(
+    const std::string& server_address, int64 max_cache_size)
+    : max_cache_size_(max_cache_size) {
+  // Ensure that large TPU programs can be sent over the channel.
+  ::grpc::ChannelArguments args;
+  args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
+  auto channel =
+      ::grpc::CreateCustomChannel(absl::StrCat("dns:///", server_address),
+                                  CreateChannelCredentials(), args);
+  stub_ = tpu::grpc::TpuCompilationCacheService::NewStub(channel);
+  VLOG(1) << "Created RPC lookup cache size " << max_cache_size_ << " bytes.";
+}
+
+Status TpuCompilationCacheRpcLookup::Lookup(
+    const std::string& proto_key,
+    std::unique_ptr<CompilationCacheEntryRef>* entry,
+    tpu::CompilationCacheFetchTarget fetch_target) {
+  profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache lookup",
+                                         /*level=*/2);
+  entry->reset();
+  std::shared_ptr<CacheEntry> cache_entry;
+  // Keep a reference to CacheEntry objects evicted from the cache so that the
+  // potential deletion happens outside the lock upon method exit.
+  std::vector<std::shared_ptr<CacheEntry>> removed_entries;
+
+  std::string local_proto_key = absl::StrCat(
+      proto_key, "_", tpu::CompilationCacheFetchTarget_Name(fetch_target));
+
+  {
+    absl::MutexLock lock(&mu_);
+    auto iter = cache_.find(local_proto_key);
+    if (iter == cache_.end()) {
+      tpu::GetTpuProgramRequest request;
+      request.set_key(proto_key);
+      request.set_fetch_target(fetch_target);
+      TF_RETURN_IF_ERROR(
+          RemoteLookupLocked(local_proto_key, request, &cache_entry));
+    } else {
+      VLOG(1) << "Found key " << local_proto_key << " in local proto cache.";
+      cache_entry = iter->second;
+      auto erased = entries_by_last_use_.erase(cache_entry->last_use);
+      CHECK_EQ(erased, 1);
+    }
+    PostLookupLocked(&cache_entry, entry, &removed_entries);
+  }
+  return Status::OK();
+}
+
+Status TpuCompilationCacheRpcLookup::Lookup(
+    int64 uid, int proto_index,
+    std::unique_ptr<CompilationCacheEntryRef>* entry,
+    tpu::CompilationCacheFetchTarget fetch_target) {
+  profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache lookup by uid",
+                                         /*level=*/2);
+  entry->reset();
+  std::shared_ptr<CacheEntry> cache_entry;
+  // Keep a reference to CacheEntry objects evicted from the cache so that the
+  // potential deletion happens outside the lock upon method exit.
+  std::vector<std::shared_ptr<CacheEntry>> removed_entries;
+
+  // Make a string key so that we can uniformly store cached entries under
+  // string keys whether they are looked up by proto_key or uid+index. The
+  // expectation is that any given executable will only ever be looked up
+  // *either* by proto_key *or* by uid+index, so we are not concerned that the
+  // same proto could be placed in the cache twice if it is looked up by both
+  // methods.
+  std::string local_proto_key =
+      absl::StrCat(" _ ", uid, ":", proto_index, "_",
+                   tpu::CompilationCacheFetchTarget_Name(fetch_target));
+  {
+    absl::MutexLock lock(&mu_);
+    auto iter = cache_.find(local_proto_key);
+    if (iter == cache_.end()) {
+      tpu::GetTpuProgramRequest request;
+      tpu::TpuCompilationUidAndIndex* uid_and_index =
+          request.mutable_uid_and_index();
+      uid_and_index->set_uid(uid);
+      uid_and_index->set_proto_index(proto_index);
+      request.set_fetch_target(fetch_target);
+      TF_RETURN_IF_ERROR(
+          RemoteLookupLocked(local_proto_key, request, &cache_entry));
+    } else {
+      VLOG(1) << "Found uid " << uid << " and index " << proto_index
+              << " in local proto cache.";
+      cache_entry = iter->second;
+      auto erased = entries_by_last_use_.erase(cache_entry->last_use);
+      CHECK_EQ(erased, 1);
+    }
+    PostLookupLocked(&cache_entry, entry, &removed_entries);
+  }
+  return Status::OK();
+}
+
+Status TpuCompilationCacheRpcLookup::RemoteLookupLocked(
+    const std::string& local_proto_key,
+    const tpu::GetTpuProgramRequest& request,
+    std::shared_ptr<CacheEntry>* cache_entry) {
+  profiler::TraceMe proto_lookup_traceme("Remote TPU proto cache fetch",
+                                         /*level=*/2);
+  // Perform the RPC while holding the lock unless it is demonstrated that
+  // this causes a performance problem.
+  ::grpc::ClientContext client_context;
+  client_context.set_deadline(TimeToGprTimespec(::absl::Now() + kProtoTimeout));
+  client_context.set_compression_algorithm(GRPC_COMPRESS_GZIP);
+
+  ResponseType response;
+  Status s =
+      FromGrpcStatus(stub_->GetTpuProgram(&client_context, request, &response));
+  VLOG(1) << "Looked up key " << local_proto_key
+          << " in remote subgraph cache status " << s;
+  TF_RETURN_IF_ERROR(s);
+
+  TF_RETURN_IF_ERROR(FillCacheEntryFromGetTpuProgramResponse(
+      local_proto_key, &response, cache_entry));
+  cache_.emplace(local_proto_key, (*cache_entry));
+  cache_size_ += (*cache_entry)->size;
+
+  return Status::OK();
+}
+
+void TpuCompilationCacheRpcLookup::PostLookupLocked(
+    std::shared_ptr<CacheEntry>* cache_entry,
+    std::unique_ptr<CompilationCacheEntryRef>* entry,
+    std::vector<std::shared_ptr<CacheEntry>>* removed_entries) {
+  (*cache_entry)->last_use = use_counter_++;
+  entries_by_last_use_[(*cache_entry)->last_use] = cache_entry->get();
+  *entry =
+      std::unique_ptr<CompilationCacheEntryRef>(new CacheWrapper(*cache_entry));
+
+  // Evict overflowing entries if necessary, but never evict the most recently
+  // used entry.
+  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
+    auto entry_to_evict = entries_by_last_use_.begin()->second;
+    entries_by_last_use_.erase(entry_to_evict->last_use);
+    CHECK_GE(cache_size_, entry_to_evict->size);
+    cache_size_ -= entry_to_evict->size;
+    // Delete the cache's reference to the entry, though clients may still be
+    // holding onto references. We use 'removed_entries' to delay the possible
+    // CacheEntry destruction until the mu_ lock is released.
+    auto entry_to_evict_it = cache_.find(entry_to_evict->key);
+    CHECK(entry_to_evict_it != cache_.end())
+        << "Missing entry key: " << entry_to_evict->key;
+    removed_entries->push_back(entry_to_evict_it->second);
+    cache_.erase(entry_to_evict_it);
+  }
+}
+
+std::string TpuCompilationCacheRpcLookup::DebugString() const {
+  return "TpuCompilationCacheRpcLookup";
+}
+}  // namespace tpu
+}  // namespace tensorflow
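The class above is a drop-in remote implementation of TpuCompilationCacheLookup: misses go over gRPC, hits are served from the local LRU cache. A minimal usage sketch, not part of this patch (the server address and cache size are hypothetical):

    // Sketch: construct a remote lookup cache and fetch the MAIN program for
    // a compilation key. Types and signatures are as declared in the new
    // header below.
    auto lookup =
        absl::make_unique<tensorflow::tpu::TpuCompilationCacheRpcLookup>(
            /*server_address=*/"localhost:8470", /*max_cache_size=*/1LL << 30);
    std::unique_ptr<tensorflow::tpu::CompilationCacheEntryRef> entry;
    tensorflow::Status s = lookup->Lookup(
        "some_proto_key", &entry,
        tensorflow::tpu::CompilationCacheFetchTarget::MAIN);
    if (s.ok()) {
      tensorflow::tpu::TpuCompilationCacheEntry e = entry->get();
      // e.tpu_program_group() may be nullptr for absent sharding entries.
    }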
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h
new file mode 100644
index 0000000..d5449a0
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h
@@ -0,0 +1,95 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Class for looking up and caching TPU programs via RPC.
+class TpuCompilationCacheRpcLookup : public TpuCompilationCacheLookup {
+ public:
+  using StubType = tpu::grpc::TpuCompilationCacheService::Stub;
+
+  TpuCompilationCacheRpcLookup(const string& server_address,
+                               int64 max_cache_size);
+  ~TpuCompilationCacheRpcLookup() override = default;
+
+  Status Lookup(const string& proto_key,
+                std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
+                tpu::CompilationCacheFetchTarget fetch_target) override;
+
+  Status Lookup(int64 uid, int proto_index,
+                std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
+                tpu::CompilationCacheFetchTarget fetch_target) override;
+
+  string DebugString() const override;
+
+ private:
+  // Helper method to make the RPC request to the central cache.
+  Status RemoteLookupLocked(const string& local_proto_key,
+                            const tpu::GetTpuProgramRequest& request,
+                            std::shared_ptr<CacheEntry>* cache_entry)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Helper method to adjust data structures after a cache lookup.
+  // We use `removed_entries` so that actual CacheEntry destruction happens
+  // outside the lock.
+  void PostLookupLocked(
+      std::shared_ptr<CacheEntry>* cache_entry,
+      std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
+      std::vector<std::shared_ptr<CacheEntry>>* removed_entries)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // The maximum total size of entries stored in the cache before entries
+  // are evicted.
+  const int64 max_cache_size_;
+
+  std::unique_ptr<StubType> stub_;
+
+  // Protect concurrent access to member variables below.
+  mutable absl::Mutex mu_;
+
+  // The total size of entries in the cache.
+  int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
+  // The value to assign to the last_use field of the next entry that is looked
+  // up.
+  int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
+  // The entries that can be looked up in the cache. An entry is deleted from
+  // the cache as soon as it is evicted, but the underlying shared_ptr won't be
+  // freed until any wrappers holding it go out of scope.
+  std::unordered_map<std::string, std::shared_ptr<CacheEntry>> cache_
+      ABSL_GUARDED_BY(mu_);
+  // Map from last_use to entry, used to evict entries in LRU order.
+  std::map<int64, CacheEntry*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
+};
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_
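The header pairs an unordered_map (key to entry) with an ordered std::map (last_use counter to entry) to get LRU eviction without an intrusive list: touching an entry re-keys it under a fresh counter value, and eviction pops from the front of the ordered map. A self-contained sketch of the same two-index idea, generic code rather than TF code:

    #include <cstdint>
    #include <map>
    #include <memory>
    #include <string>
    #include <unordered_map>

    struct Entry {
      std::string key;
      int64_t size = 0;
      int64_t last_use = -1;
    };

    class LruBySize {
     public:
      explicit LruBySize(int64_t max_size) : max_size_(max_size) {}

      void Insert(std::shared_ptr<Entry> e) {
        size_ += e->size;
        cache_[e->key] = e;
        Touch(e.get());
        // Evict in LRU order, but never the entry just touched.
        while (by_last_use_.size() > 1 && size_ > max_size_) {
          Entry* victim = by_last_use_.begin()->second;
          by_last_use_.erase(victim->last_use);
          size_ -= victim->size;
          auto it = cache_.find(victim->key);
          if (it != cache_.end()) cache_.erase(it);  // shared_ptr keeps any
                                                     // live refs valid
        }
      }

     private:
      void Touch(Entry* e) {
        // Re-keying under a fresh counter marks the entry most recently used.
        by_last_use_.erase(e->last_use);
        e->last_use = use_counter_++;
        by_last_use_[e->last_use] = e;
      }

      const int64_t max_size_;
      int64_t size_ = 0;
      int64_t use_counter_ = 0;
      std::unordered_map<std::string, std::shared_ptr<Entry>> cache_;
      std::map<int64_t, Entry*> by_last_use_;
    };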
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc
new file mode 100644
index 0000000..0e77edf
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
+
+#if defined(LIBTFTPU)
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
+#endif  // LIBTFTPU
+
+namespace tensorflow {
+namespace tpu {
+std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
+  return ::grpc::InsecureChannelCredentials();  // NOLINT
+}
+
+#if defined(LIBTFTPU)
+template <>
+Status FillCacheEntryFromGetTpuProgramResponse<GetTpuProgramResponseExternal>(
+    absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
+    std::shared_ptr<CacheEntry>* cache_entry) {
+  // TODO(b/162904194): implement this method.
+  LOG(FATAL) << "Not implemented yet.";
+}
+
+void SendGetTpuProgramResponseHelper(
+    const TpuCompilationCacheEntry& cache_entry,
+    std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn) {
+  // TODO(b/162904194): implement this method.
+  LOG(FATAL) << "Not implemented yet.";
+}
+#endif  // LIBTFTPU
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h
new file mode 100644
index 0000000..6749138
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h
@@ -0,0 +1,96 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_SUPPORT_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_SUPPORT_H_
+
+#include <grpcpp/security/credentials.h>
+
+#include <functional>
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
+#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// A cache entry for remote TPU compilation.
+struct CacheEntry {
+  CacheEntry() : size(0), last_use(-1) {}
+  virtual ~CacheEntry() {
+    if (tpu_program_group != nullptr) {
+      tpu_program_group->UnloadAndDestroyPrograms();
+    }
+  }
+  std::unique_ptr<TpuProgramGroupInterface> tpu_program_group;
+  std::string key;
+  int64 size;
+
+  // A monotonically increasing counter used by the TPU compilation cache to
+  // sort and evict the least recently used entry when the cache size exceeds
+  // the maximum size limit. Initialized to `-1` until the entry is first
+  // used.
+  int64 last_use;
+};
+
+// Implementation of `CompilationCacheEntryRef` that holds a shared_ptr to the
+// local cache entry until the wrapper is destroyed.
+class CacheWrapper : public CompilationCacheEntryRef {
+ public:
+  explicit CacheWrapper(std::shared_ptr<CacheEntry> entry)
+      : cache_entry_(std::move(entry)) {}
+  ~CacheWrapper() override = default;
+
+  TpuCompilationCacheEntry get() override {
+    if (cache_entry_->size == 0) {
+      // Create an empty entry if the size is 0. This corresponds to
+      // non-existing sharding/unsharding entries.
+      return TpuCompilationCacheEntry();
+    }
+    return TpuCompilationCacheEntry(cache_entry_->tpu_program_group.get(),
+                                    /*core_index=*/0);
+  }
+
+  Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override {
+    LOG(FATAL) << "Not implemented by design.";
+  }
+
+ private:
+  std::shared_ptr<CacheEntry> cache_entry_;
+};
+
+// Creates gRPC channel credentials for the current runtime env.
+std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials();
+
+// Fills an uninitialized `CacheEntry` from a `GetTpuProgramResponse` proto.
+// The `cache_entry` will be instantiated by the function.
+template <typename ResponseType>
+Status FillCacheEntryFromGetTpuProgramResponse(
+    const absl::string_view local_proto_key, ResponseType* response,
+    std::shared_ptr<CacheEntry>* cache_entry);
+
+// A helper to send `TpuCompilationCacheEntry` payload through gRPC channel.
+void SendGetTpuProgramResponseHelper(
+    const TpuCompilationCacheEntry& cache_entry,
+    std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn);
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_SUPPORT_H_
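FillCacheEntryFromGetTpuProgramResponse is declared above without a primary definition; each response type gets an explicit specialization in a .cc file (the LIBTFTPU one is stubbed in tpu_compilation_cache_rpc_support.cc above). A minimal standalone sketch of that declare-in-header, specialize-in-one-TU idiom, with hypothetical names:

    // In a header: declaration only, no definition of the primary template.
    template <typename ResponseType>
    int FillFromResponse(ResponseType* response);

    // In one .cc file: an explicit specialization for a concrete type.
    struct FakeResponse {
      int value = 42;
    };

    template <>
    int FillFromResponse<FakeResponse>(FakeResponse* response) {
      return response->value;
    }

    // Callers of FillFromResponse<FakeResponse>(...) link against the
    // specialization; using an unspecialized type fails at link time instead
    // of silently picking a wrong implementation.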
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.cc
new file mode 100644
index 0000000..5abd0c7
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.cc
@@ -0,0 +1,171 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h"
+
+#include <chrono>  // NOLINT
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/platform/coding.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
+
+namespace tensorflow {
+namespace {
+using ::tensorflow::tpu::CompilationCacheEntryRef;
+using ::tensorflow::tpu::TpuCompilationCacheEntry;
+using ::tensorflow::tpu::TpuCompilationCacheInterface;
+
+static constexpr int kGetTpuProgramServingThreads = 32;
+}  // namespace
+
+TpuCompilationCacheService::TpuCompilationCacheService(
+    ::grpc::ServerBuilder* server_builder, TpuCompilationCacheInterface* cache)
+    : cache_(cache),
+      server_builder_(server_builder),
+      cq_(server_builder_->AddCompletionQueue()),
+      thread_pool_(absl::make_unique<thread::ThreadPool>(
+          Env::Default(), "TpuCompilationCacheService",
+          kGetTpuProgramServingThreads)) {
+  cache_->Ref();
+  server_builder_->RegisterService(&service_);
+}
+
+TpuCompilationCacheService::~TpuCompilationCacheService() {
+  // This ordering is important. We must first shut down our CQ and allow the
+  // polling thread and dispatch pool to shut down before releasing our cache
+  // reference. The gRPC server must be Shutdown() by this point or we will
+  // deadlock here. The running_ boolean is necessary to avoid adding new
+  // operations to the CQ after it has shut down.
+  running_ = false;
+  cq_->Shutdown();
+  polling_thread_.reset();
+  thread_pool_.reset();
+  cache_->Unref();
+}
+
+void TpuCompilationCacheService::Start() {
+  server_ = server_builder_->BuildAndStart();
+  ThreadOptions opts;
+  polling_thread_.reset(Env::Default()->StartThread(
+      opts, "TpuCompilationCachePoller", [this]() { HandleRPCsLoop(); }));
+}
+
+bool TpuCompilationCacheService::Shutdown(int timeout_sec) {
+  if (server_ != nullptr) {
+    std::chrono::system_clock::time_point timeout =
+        std::chrono::system_clock::now() + std::chrono::seconds(timeout_sec);
+    server_->Shutdown(timeout);
+    if (std::chrono::system_clock::now() >= timeout) {
+      return false;
+    }
+    return true;
+  } else {
+    return false;
+  }
+}
+
+void TpuCompilationCacheService::SetMemoryQuota(size_t max_bytes) {
+  ::grpc::ResourceQuota quota;
+  quota.Resize(max_bytes);
+  server_builder_->SetResourceQuota(quota);
+}
+
+// Fetch a cache result for the given request and serialize the result directly
+// into a ByteBuffer.
+void TpuCompilationCacheService::GetTpuProgram(GetTpuProgramCall* call) {
+  std::unique_ptr<CompilationCacheEntryRef> entry;
+
+  VLOG(1) << "GetTpuProgram: " << call->request.DebugString();
+  Status s;
+  switch (call->request.key_oneof_case()) {
+    case tpu::GetTpuProgramRequest::kKey:
+      s = cache_->Lookup(call->request.key(), &entry);
+      break;
+
+    case tpu::GetTpuProgramRequest::kUidAndIndex:
+      s = cache_->Lookup(call->request.uid_and_index().uid(),
+                         call->request.uid_and_index().proto_index(), &entry);
+      break;
+
+    default:
+      s = errors::Internal("Bad GetTpuProgram RPC request oneof case ",
+                           call->request.key_oneof_case());
+      break;
+  }
+  if (!s.ok()) {
+    return call->SendResponse(ToGrpcStatus(s));
+  }
+
+  s = entry->ToSubEntryRef(call->request.fetch_target());
+  if (!s.ok()) {
+    return call->SendResponse(::grpc::Status(
+        ::grpc::StatusCode::INVALID_ARGUMENT,
+        absl::StrCat(
+            "Error getting the fetching target ",
+            CompilationCacheFetchTarget_Name(call->request.fetch_target())),
+        s.error_message()));
+  }
+
+  TpuCompilationCacheEntry cache_entry = entry->get();
+  if (cache_entry.tpu_program_group() == nullptr) {
+    // It's possible that the sharding/unsharding entry does not exist, but the
+    // main entry must exist.
+    CHECK_NE(call->request.fetch_target(),
+             tpu::CompilationCacheFetchTarget::MAIN);
+  }
+  return SendGetTpuProgramResponseHelper(
+      cache_entry,
+      [&call](::grpc::ByteBuffer* buffer, ::grpc::Status error_status) {
+        if (buffer == nullptr) {
+          return call->SendResponse(error_status);
+        }
+        call->response = *buffer;
+        return call->SendResponse(::grpc::Status());
+      });
+}
+
+void TpuCompilationCacheService::HandleGetTpuProgram(GetTpuProgramCall* call) {
+  thread_pool_->Schedule([this, call]() { GetTpuProgram(call); });
+  if (running_) {
+    GetTpuProgramCall::EnqueueRequestForMethod(
+        &service_, cq_.get(),
+        static_cast<int>(ServiceType::MethodId::kGetTpuProgram),
+        &TpuCompilationCacheService::HandleGetTpuProgram,
+        /*supports_cancel=*/false);
+  }
+}
+
+void TpuCompilationCacheService::HandleRPCsLoop() {
+  void* tag;
+  bool ok;
+
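+  // Pre-enqueue a batch of pending GetTpuProgram calls so several requests
+  // can be served concurrently; each completed call re-enqueues a
+  // replacement in HandleGetTpuProgram.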
+  for (int i = 0; i < 50; ++i) {
+    GetTpuProgramCall::EnqueueRequestForMethod(
+        &service_, cq_.get(),
+        static_cast<int>(ServiceType::MethodId::kGetTpuProgram),
+        &TpuCompilationCacheService::HandleGetTpuProgram,
+        /*supports_cancel=*/false);
+  }
+
+  while (cq_->Next(&tag, &ok)) {
+    VLOG(2) << "HandleRPCS: " << tag;
+    UntypedCall<TpuCompilationCacheService>::Tag* callback_tag =
+        static_cast<UntypedCall<TpuCompilationCacheService>::Tag*>(tag);
+    callback_tag->OnCompleted(this, ok);
+  }
+
+  VLOG(2) << "Cache thread shutting down.";
+}
+}  // namespace tensorflow
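The service follows the usual asynchronous completion-queue shape: pre-post a fixed number of pending calls, then one poller thread drains the queue while completed tags re-arm themselves. A scheduler-only sketch of that control flow (plain C++, no gRPC; all names hypothetical):

    #include <functional>
    #include <iostream>
    #include <queue>

    int main() {
      // Stand-in for a completion queue: each "tag" handles one event and
      // may re-enqueue a replacement, as HandleGetTpuProgram does above.
      std::queue<std::function<void()>> cq;
      int handled = 0;
      const int total_requests = 5;
      std::function<void()> handler = [&]() {
        if (handled == total_requests) return;  // shutting down: don't re-arm
        ++handled;
        cq.push(handler);
      };
      // Pre-post a batch so several calls can be in flight at once.
      for (int i = 0; i < 3; ++i) cq.push(handler);
      // Poller loop: drain events until no tags remain (cq_->Next() in the
      // real service blocks instead, returning false only after Shutdown()).
      while (!cq.empty()) {
        std::function<void()> tag = cq.front();
        cq.pop();
        tag();
      }
      std::cout << "handled " << handled << " calls\n";  // prints 5
    }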
diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h
new file mode 100644
index 0000000..6682be8
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h
@@ -0,0 +1,70 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SERVICE_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SERVICE_H_
+
+#include <atomic>
+#include <memory>
+
+#include "grpcpp/server_builder.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+
+namespace tensorflow {
+// gRPC service for handling CompilationCache requests.
+// To avoid OOMs during execution, this service uses the asynchronous raw gRPC
+// interface to serialize cache results directly to gRPC byte buffers. This
+// allows us to control serialization concurrency and avoids making an extra
+// copy of the program cache for each worker.
+class TpuCompilationCacheService {
+ public:
+  using ServiceType = ::tensorflow::tpu::grpc::TpuCompilationCacheService;
+  using AsyncService = ServiceType::AsyncService;
+
+  TpuCompilationCacheService(::grpc::ServerBuilder* server_builder,
+                             tpu::TpuCompilationCacheInterface* cache);
+  ~TpuCompilationCacheService();
+
+  void Start();
+  bool Shutdown(int timeout_sec);
+  void SetMemoryQuota(size_t max_bytes);
+
+ private:
+  void HandleRPCsLoop();
+
+  using GetTpuProgramCall = Call<TpuCompilationCacheService, AsyncService,
+                                 tpu::GetTpuProgramRequest, ::grpc::ByteBuffer>;
+
+  // Schedule the cache fetch into the serving thread pool.
+  void HandleGetTpuProgram(GetTpuProgramCall* call);
+
+  // Performs the actual cache fetch and serialization.
+  void GetTpuProgram(GetTpuProgramCall* call);
+
+  std::atomic<bool> running_ = true;
+  tpu::TpuCompilationCacheInterface* cache_;
+  ::grpc::ServerBuilder* server_builder_;
+  std::unique_ptr<::grpc::Server> server_;
+  std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
+  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  std::unique_ptr<Thread> polling_thread_;
+  AsyncService service_;
+};
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SERVICE_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h
index 4460763..0137b3f 100644
--- a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h
+++ b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h
@@ -15,11 +15,12 @@
 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
 
+#include <stddef.h>
+
 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
 #include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
-#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
 #include "tensorflow/core/tpu/libtftpu.h"
-#include "tensorflow/stream_executor/tpu/proto_helper.h"
+#include "tensorflow/stream_executor/tpu/c_api_decl.h"
 
 extern "C" {
 
diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.cc b/tensorflow/core/tpu/kernels/tpu_pod_state.cc
new file mode 100644
index 0000000..a45a4d6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_pod_state.cc
@@ -0,0 +1,70 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
+
+#include "tensorflow/core/tpu/kernels/tpu_util.h"
+
+namespace tensorflow {
+
+const char kTpuPodStateResourceName[] = "tpu_pod_state";
+
+TpuPodState::TpuPodState(
+    int service_port, std::unique_ptr<TpuCompilationCacheService> cache_service)
+    : cache_service_(std::move(cache_service)), service_port_(service_port) {}
+
+TpuPodState::~TpuPodState() {
+  if (cache_service_) {
+    VLOG(1) << "Shutting down Compilation Cache Service.";
+    if (cache_service_->Shutdown(20)) {
+      if (service_port_ >= 0) {
+        tpu::RecycleUnusedPort(service_port_);
+      }
+    } else {
+      LOG(ERROR)
+          << "Failed to shutdown Compilation Cache Service within timeout.";
+    }
+  }
+  VLOG(1) << "Shutting down Compilation Cache Service done.";
+}
+
+string TpuPodState::DebugString() const {
+  return "Wrapper for distributed TPU state";
+}
+
+Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state) {
+  if (!rmgr) {
+    return errors::Internal("No resource manager.");
+  }
+  if (!rmgr->Lookup(rmgr->default_container(), kTpuPodStateResourceName,
+                    pod_state)
+           .ok()) {
+    return errors::FailedPrecondition(
+        "The TPU system has not been initialized.");
+  }
+  return Status::OK();
+}
+
+bool HasTPUPodState(const ResourceMgr* rmgr) {
+  TpuPodState* pod_state;
+  if (!rmgr->Lookup(rmgr->default_container(), kTpuPodStateResourceName,
+                    &pod_state)
+           .ok()) {
+    return false;
+  }
+  pod_state->Unref();
+  return true;
+}
+
+}  // namespace tensorflow
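GetTPUPodState hands back a ref-counted resource (ResourceMgr::Lookup takes a reference), so callers are expected to pair the lookup with an Unref. A short sketch, assuming a valid ResourceMgr* rmgr is in scope inside a function returning Status:

    tensorflow::TpuPodState* pod_state;
    TF_RETURN_IF_ERROR(tensorflow::GetTPUPodState(rmgr, &pod_state));
    tensorflow::core::ScopedUnref unref(pod_state);  // balances Lookup's Ref
    VLOG(1) << pod_state->DebugString();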
diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.h b/tensorflow/core/tpu/kernels/tpu_pod_state.h
new file mode 100644
index 0000000..9f37e28
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_pod_state.h
@@ -0,0 +1,54 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
+
+#include "grpcpp/server_builder.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h"
+
+namespace tensorflow {
+
+// Name of the TPU pod state resource.
+ABSL_CONST_INIT extern const char kTpuPodStateResourceName[];
+
+// Wrapper to hold centralized state for the distributed TPU in the TPU_SYSTEM
+// device's resource manager.
+class TpuPodState : public ResourceBase {
+ public:
+  // The port number given by service_port will be freed with
+  // RecycleUnusedPort in the destructor if it is non-negative.
+  TpuPodState(int service_port,
+              std::unique_ptr<TpuCompilationCacheService> cache_service);
+
+  ~TpuPodState() override;
+
+  string DebugString() const override;
+
+ private:
+  std::unique_ptr<TpuCompilationCacheService> cache_service_;
+  int service_port_;
+};
+
+// Returns the TPU pod state or an error.
+Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state);
+
+// Checks whether the TPU pod state configuration is present within the
+// resource manager.
+bool HasTPUPodState(const ResourceMgr* rmgr);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_util.cc b/tensorflow/core/tpu/kernels/tpu_util.cc
index 60f8fe0..837c23c 100644
--- a/tensorflow/core/tpu/kernels/tpu_util.cc
+++ b/tensorflow/core/tpu/kernels/tpu_util.cc
@@ -16,6 +16,7 @@
 
 #include "absl/strings/str_split.h"
 #include "tensorflow/core/platform/random.h"
+#include "tensorflow/core/tpu/tpu_api.h"
 
 namespace tensorflow {
 namespace tpu {
@@ -95,5 +96,9 @@
   }
   return Status::OK();
 }
+
+void RecycleUnusedPort(int port) {
+  UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(port);
+}
 }  // namespace tpu
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_util.h b/tensorflow/core/tpu/kernels/tpu_util.h
index 579fbdf..834db31 100644
--- a/tensorflow/core/tpu/kernels/tpu_util.h
+++ b/tensorflow/core/tpu/kernels/tpu_util.h
@@ -54,6 +54,11 @@
                                    std::vector<TensorShape>* shapes);
 Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
                                    std::vector<TensorShape>* shapes);
+
+// We only recycle ports which were given to us by the portserver. For ports
+// we obtained through local trial-and-error, there is no reason to expect the
+// port to remain available after it is unbound.
+void RecycleUnusedPort(int port);
 }  // namespace tpu
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/tpu/kernels/tpu_util_c_api.h b/tensorflow/core/tpu/kernels/tpu_util_c_api.h
index ddc7a84..04b65e2 100644
--- a/tensorflow/core/tpu/kernels/tpu_util_c_api.h
+++ b/tensorflow/core/tpu/kernels/tpu_util_c_api.h
@@ -56,6 +56,9 @@
 TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
     const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);
 
+// Recycles an unused service port.
+TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port);
+
 // Creates a unique compilation cache `key` used for `put` and `get` operations.
 // Returned buffers are heap-allocated and must be owned.
 TFTPU_CAPI_EXPORT CompilationCacheKeyResult
@@ -79,6 +82,7 @@
   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
   TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
+  TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort);
   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);
diff --git a/tensorflow/core/tpu/kernels/transfer_ops.cc b/tensorflow/core/tpu/kernels/transfer_ops.cc
index 40b85e2..a5cdfd4 100644
--- a/tensorflow/core/tpu/kernels/transfer_ops.cc
+++ b/tensorflow/core/tpu/kernels/transfer_ops.cc
@@ -69,7 +69,8 @@
 }
 
 Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
-  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
+  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
+      /*initialize_platform=*/false);
 
   int real_device_ordinal = device_ordinal_;
   if (real_device_ordinal < 0) {
diff --git a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
index f7c33e5..fc15d71 100644
--- a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
+++ b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc
@@ -116,12 +116,44 @@
                       indices_shape.dim_size(d)));
     }
     xla::XlaBuilder* builder = ctx->builder();
+    // data shape = [indices_shape, segment_shape]
+    // buffer shape = [num_segments, segment_shape]
+    // We now create the buffer shape by reverse engineering the data shape
+    // into an indices shape and a segment shape.
     TensorShape buffer_shape = data_shape;
     buffer_shape.RemoveDimRange(0, indices_shape.dims());
     buffer_shape.InsertDim(0, num_segments);
+
     auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype_),
                                  buffer_shape.dim_sizes());
 
+    // Build dynamic dim sizes for the buffer, as well as whether each
+    // dimension size is dynamic or static. We build two parts: the
+    // num_segments part and the segment_shape part.
+    std::vector<xla::XlaOp> buffer_dims;
+    std::vector<bool> buffer_dims_are_dynamic;
+    // Build the "num_segment" part.
+    bool num_segments_is_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic));
+
+    buffer_dims.insert(buffer_dims.begin(), ctx->Input(2));
+    buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(),
+                                   num_segments_is_dynamic);
+    // Build the segment shape part.
+    for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) {
+      buffer_dims.push_back(xla::GetDimensionSize(data, i));
+      buffer_dims_are_dynamic.push_back(
+          ctx->InputXlaShape(0)->is_dynamic_dimension(i));
+    }
+
+    for (int64 i = 0; i < buffer_dims.size(); ++i) {
+      if (buffer_dims_are_dynamic[i]) {
+        // For each dynamic dimension, call set-dimension-size on it.
+        buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i);
+      }
+    }
+
     auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
       return a + b;
     };
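xla::SetDimensionSize is what carries the dynamism here: the buffer is first materialized at its static bound, then each dynamic dimension is tagged with its runtime size. That step in isolation, as a sketch with hypothetical bounds:

    // Sketch: a zero buffer with static bound [kMaxSegments, 128] whose
    // leading dimension is marked dynamic with runtime size `num_segments`.
    xla::XlaBuilder b("dynamic_buffer");
    xla::XlaOp num_segments = xla::Parameter(
        &b, 0, xla::ShapeUtil::MakeShape(xla::S32, {}), "num_segments");
    xla::XlaOp buffer = xla::Broadcast(xla::ConstantR0<float>(&b, 0.0f),
                                       {/*kMaxSegments=*/1024, 128});
    buffer = xla::SetDimensionSize(buffer, num_segments, /*dimension=*/0);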
diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
index 2f11e06..4dc0977 100644
--- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
+++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc
@@ -72,7 +72,7 @@
 }
 
 bool FindAndLoadTpuLibrary() {
-  void* library = dlopen("libtftpu.so", RTLD_NOW);
+  void* library = dlopen("libtpu.so", RTLD_NOW);
   if (library) {
     InitializeTpuLibrary(library);
   }
diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h
index 55de89d..08417db 100644
--- a/tensorflow/core/tpu/tpu_config_c_api.h
+++ b/tensorflow/core/tpu/tpu_config_c_api.h
@@ -26,9 +26,6 @@
 
 namespace tensorflow {
 class TpuMeshCommonState;
-namespace tpu {
-class TpuMeshStateInterface;
-}  // namespace tpu
 }  // namespace tensorflow
 
 extern "C" {
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index 16494d0..6a1432e 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -1,4 +1,8 @@
+#if defined(PLATFORM_GOOGLE)
 #include "third_party/tensorflow/core/tpu/tpu_executor_init_fns.inc"
+#else
+#include "tensorflow/core/tpu/tpu_executor_init_fns.inc"
+#endif
 
 namespace {
 
@@ -88,6 +92,7 @@
   auto* util_fn = tensorflow::tpu::UtilApiFn();
 
   TFTPU_SET_FN(util_fn, TpuTopology_AvailableCoreCount);
+  TFTPU_SET_FN(util_fn, TpuNetUtil_RecycleUnusedPort);
   TFTPU_SET_FN(util_fn, TpuCompile_IsTpuCompilationEnabled);
   TFTPU_SET_FN(util_fn, TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
   TFTPU_SET_FN(util_fn, TpuCompile_CreateCompilationCacheKey);
diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD
index 7b2c5c4..13e31e9 100644
--- a/tensorflow/core/util/BUILD
+++ b/tensorflow/core/util/BUILD
@@ -17,6 +17,7 @@
     "tf_kernel_library",
 )
 load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 load(
     "//third_party/mkl:build_defs.bzl",
     "mkl_deps",
@@ -364,6 +365,7 @@
 tf_version_info_genrule(
     name = "version_info_gen",
     out = "version_info.cc",
+    compatible_with = get_compatible_with_portable(),
 )
 
 cc_library(
@@ -507,6 +509,7 @@
     name = "version_info",
     srcs = ["version_info.cc"],
     hdrs = ["//tensorflow/core/public:version.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = tf_copts(),
     alwayslink = if_static(0, 1),
 )
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index b148ffa..03acb98 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -167,11 +167,14 @@
   }
 
   // Helper methods
-  tstring& construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
-    return bytes_list->construct_at_end();
+  tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
+    if (bytes_list->EndDistance() <= 0) {
+      return nullptr;
+    }
+    return &bytes_list->construct_at_end();
   }
-  tstring& construct_at_end(SmallVector<tstring>* bytes_list) {
-    return bytes_list->emplace_back();
+  tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
+    return &bytes_list->emplace_back();
   }
 
   template <typename Result>
@@ -192,9 +195,10 @@
       // parse string
       uint32 bytes_length;
       if (!stream.ReadVarint32(&bytes_length)) return false;
-      tstring& bytes = construct_at_end(bytes_list);
-      bytes.resize_uninitialized(bytes_length);
-      if (!stream.ReadRaw(bytes.data(), bytes_length)) return false;
+      tstring* bytes = construct_at_end(bytes_list);
+      if (bytes == nullptr) return false;
+      bytes->resize_uninitialized(bytes_length);
+      if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
     }
     stream.PopLimit(limit);
     return true;
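The construct_at_end change above converts an out-of-capacity append into a nullptr that the caller reports as a parse failure, rather than constructing past the end of the limited slice. The same guard pattern in a standalone sketch:

    #include <string>
    #include <vector>

    // Sketch: append-slot provider that refuses to hand out a slot once a
    // fixed capacity is reached; callers treat nullptr as a parse error.
    class LimitedBuffer {
     public:
      explicit LimitedBuffer(size_t capacity) : capacity_(capacity) {}
      std::string* construct_at_end() {
        if (items_.size() >= capacity_) return nullptr;  // would overflow
        items_.emplace_back();
        return &items_.back();
      }

     private:
      size_t capacity_;
      std::vector<std::string> items_;
    };

    bool ParseOne(LimitedBuffer* out, const std::string& payload) {
      std::string* slot = out->construct_at_end();
      if (slot == nullptr) return false;  // surface overflow as parse failure
      *slot = payload;
      return true;
    }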
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index 0df810a..1cf9a8c 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -59,6 +59,11 @@
   // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
   // it will be 1. A shrunk dimension is skipped.
   gtl::InlinedVector<int32, 4> final_shape_gather_indices;
+  // This vector has the same size as final_shape_gather_indices, but it
+  // remembers the sparse index that a dimension comes from, instead of the
+  // dense index. A -1 in this vector means the index is not from the sparse
+  // input.
+  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
   // The dense indexed shrink mask is which processing dimensions
   // should be shrunk. For example, if foo.shape = (10,10,10,10)
   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
@@ -108,9 +113,11 @@
           dense->begin_mask |= (1 << full_index);
           dense->end_mask |= (1 << full_index);
           dense->final_shape_gather_indices.push_back(full_index);
+          dense->final_shape_gather_indices_sparse.push_back(-1);
         }
       } else if ((1 << i) & sparse.new_axis_mask) {
         dense->final_shape_gather_indices.push_back(kNewAxis);
+        dense->final_shape_gather_indices_sparse.push_back(-1);
       } else {
         if (full_index == dense->begin.size()) {
           return errors::InvalidArgument("Index out of range using input dim ",
@@ -138,9 +145,13 @@
         // axis (now in dense form) so we can ignore dense->end below.
         if (sparse.shrink_axis_mask & (1 << i)) {
           dense->final_shape_gather_indices.push_back(kShrinkAxis);
+          dense->final_shape_gather_indices_sparse.push_back(-1);
           dense->shrink_axis_mask |= (1 << full_index);
         } else {
           dense->final_shape_gather_indices.push_back(full_index);
+          // Remember where in the sparse shape the dense dim comes from.
+          dense->final_shape_gather_indices_sparse.push_back(i);
         }
         full_index++;
       }
@@ -157,7 +168,9 @@
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides) {
+    gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
   const bool begin_is_wrong =
       begin_tensor != nullptr &&
       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -362,11 +375,34 @@
   // slices like foo[3,...] will reduce dimension by 1.
   // This cannot be done earlier, because it depends on Step 3.
   final_shape->Clear();
-  for (auto gather_index : dense_spec.final_shape_gather_indices) {
+  if (output_to_sparse_mapping != nullptr) {
+    output_to_sparse_mapping->clear();
+  }
+
+  if (output_to_processing_mapping != nullptr) {
+    output_to_processing_mapping->clear();
+  }
+  for (int64 dense_dim = 0;
+       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
+    int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
+    int64 sparse_index =
+        dense_spec.final_shape_gather_indices_sparse[dense_dim];
     if (gather_index >= 0) {
       final_shape->AddDim(processing_shape->dim_size(gather_index));
+      if (output_to_sparse_mapping != nullptr) {
+        output_to_sparse_mapping->push_back(sparse_index);
+      }
+      if (output_to_processing_mapping != nullptr) {
+        output_to_processing_mapping->push_back(gather_index);
+      }
     } else if (gather_index == kNewAxis) {
       final_shape->AddDim(1);
+      if (output_to_sparse_mapping != nullptr) {
+        output_to_sparse_mapping->push_back(-1);
+      }
+      if (output_to_processing_mapping != nullptr) {
+        output_to_processing_mapping->push_back(-1);
+      }
     }
   }
   return Status::OK();
@@ -379,14 +415,17 @@
     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
+    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
   // Validate with PartialTensorShape output
   PartialTensorShape partial_processing_shape, partial_final_shape;
   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
       &partial_processing_shape, &partial_final_shape, is_identity,
-      is_simple_slice, slice_dim0, begin, end, strides));
+      is_simple_slice, slice_dim0, begin, end, strides,
+      output_to_sparse_mapping, output_to_processing_mapping));
 
   // Verify that the output shapes are fully known
   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h
index 25ecccd..9e49477 100644
--- a/tensorflow/core/util/strided_slice_op.h
+++ b/tensorflow/core/util/strided_slice_op.h
@@ -40,6 +40,17 @@
 // some dimensions of <processing_shape> and/or <final_shape> may be unknown
 // (-1). Any validation that can be done without complete information is
 // performed.
+//
+// This function changes the order of dimensions; output_to_sparse_mapping and
+// output_to_processing_mapping are used to track the order change.
+//
+// output_to_sparse_mapping[i] is the index of the dim in begin_tensor that
+// output dim i corresponds to. If output_to_sparse_mapping[i] is -1, the
+// dimension doesn't come from the sparse input (e.g. it was added by
+// new_axis).
+//
+// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
+// processing_shape.
 Status ValidateStridedSliceOp(
     const Tensor* begin_tensor, const Tensor* end_tensor,
     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -48,7 +59,9 @@
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides);
+    gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
 
 // Same as above, but the outputs are TensorShape, not PartialTensorShape
 Status ValidateStridedSliceOp(
@@ -58,7 +71,9 @@
     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
-    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides);
+    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
+    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
+    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
 
 }  // namespace tensorflow
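A worked example may help in checking the mapping semantics (shapes hypothetical): for an input of shape [10, 10, 10, 10] and the slice foo[3, tf.newaxis, ..., ::2], the sparse spec has four entries (shrink, new_axis, ellipsis, stride-2). The shrink drops dense dim 0, the ellipsis expands to dense dims 1 and 2, and the final shape is [1, 10, 10, 5], giving:

    // output dim:                       0    1    2   3
    // final_shape:                     [1,  10,  10,  5]
    output_to_sparse_mapping     = {-1, -1, -1, 3};  // only output dim 3 comes
                                                     // from a sparse entry
    output_to_processing_mapping = {-1,  1,  2, 3};  // new_axis dims map to -1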
 
diff --git a/tensorflow/examples/tutorials/mnist/input_data.py b/tensorflow/examples/tutorials/mnist/input_data.py
index c203c7b..a70787d 100644
--- a/tensorflow/examples/tutorials/mnist/input_data.py
+++ b/tensorflow/examples/tutorials/mnist/input_data.py
@@ -138,7 +138,7 @@
     Args:
       images: The images
       labels: The labels
-      fake_data: Ignore inages and labels, use fake data.
+      fake_data: Ignore images and labels, use fake data.
       one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
         False).
       dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
@@ -330,4 +330,3 @@
   test = _DataSet(test_images, test_labels, **options)
 
   return _Datasets(train=train, validation=validation, test=test)
-
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index ac28c3a..60de1e1 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -495,3 +495,34 @@
 	}
 	return nil
 }
+
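+// LibraryHandler holds a handle to a TensorFlow C library loaded via
+// TF_LoadLibrary.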
+type LibraryHandler struct {
+	cptr *C.TF_Library
+}
+
+// LoadLibrary loads the library at path into the current context, which is useful for loading an ops implementation into a non-monolithic TF build. It returns a LibraryHandler, or nil and an error.
+func LoadLibrary(path string) (*LibraryHandler, error) {
+	status := newStatus()
+
+	cpath := C.CString(path)
+	defer C.free(unsafe.Pointer(cpath))
+	cptr := C.TF_LoadLibrary(cpath, status.c)
+	if cptr == nil || status.Code() != C.TF_OK {
+		return nil, fmt.Errorf("could not load library %s: code: %d, error: %s", path, status.Code(), status.String())
+	}
+
+	lh := &LibraryHandler{
+		cptr: cptr,
+	}
+
+	runtime.SetFinalizer(lh, (*LibraryHandler).free)
+	return lh, nil
+}
+
+func (lh *LibraryHandler) free() {
+	if lh == nil || lh.cptr == nil {
+		return
+	}
+
+	C.TF_DeleteLibraryHandle(lh.cptr)
+}
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 2a4b406..6b7f011 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -7748,98 +7748,6 @@
 	return op.Output(0)
 }
 
-// Does nothing. Serves as a control trigger for scheduling.
-//
-// Only useful as a placeholder for control edges.
-//
-// Returns the created operation.
-func ControlTrigger(scope *Scope) (o *tf.Operation) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "ControlTrigger",
-	}
-	return scope.AddOperation(opspec)
-}
-
-// Interleave the values from the `data` tensors into a single tensor.
-//
-// Builds a merged tensor such that
-//
-// ```python
-//     merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]
-// ```
-//
-// For example, if each `indices[m]` is scalar or vector, we have
-//
-// ```python
-//     # Scalar indices:
-//     merged[indices[m], ...] = data[m][...]
-//
-//     # Vector indices:
-//     merged[indices[m][i], ...] = data[m][i, ...]
-// ```
-//
-// Each `data[i].shape` must start with the corresponding `indices[i].shape`,
-// and the rest of `data[i].shape` must be constant w.r.t. `i`.  That is, we
-// must have `data[i].shape = indices[i].shape + constant`.  In terms of this
-// `constant`, the output shape is
-//
-//     merged.shape = [max(indices)] + constant
-//
-// Values may be merged in parallel, so if an index appears in both `indices[m][i]`
-// and `indices[n][j]`, the result may be invalid. This differs from the normal
-// DynamicStitch operator that defines the behavior in that case.
-//
-// For example:
-//
-// ```python
-//     indices[0] = 6
-//     indices[1] = [4, 1]
-//     indices[2] = [[5, 2], [0, 3]]
-//     data[0] = [61, 62]
-//     data[1] = [[41, 42], [11, 12]]
-//     data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]
-//     merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],
-//               [51, 52], [61, 62]]
-// ```
-//
-// This method can be used to merge partitions created by `dynamic_partition`
-// as illustrated on the following example:
-//
-// ```python
-//     # Apply function (increments x_i) on elements for which a certain condition
-//     # apply (x_i != -1 in this example).
-//     x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])
-//     condition_mask=tf.not_equal(x,tf.constant(-1.))
-//     partitioned_data = tf.dynamic_partition(
-//         x, tf.cast(condition_mask, tf.int32) , 2)
-//     partitioned_data[1] = partitioned_data[1] + 1.0
-//     condition_indices = tf.dynamic_partition(
-//         tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)
-//     x = tf.dynamic_stitch(condition_indices, partitioned_data)
-//     # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain
-//     # unchanged.
-// ```
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/DynamicStitch.png" alt>
-// </div>
-func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "ParallelDynamicStitch",
-		Input: []tf.Input{
-			tf.OutputList(indices), tf.OutputList(data),
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Partitions `data` into `num_partitions` tensors using indices from `partitions`.
 //
 // For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]`
@@ -15011,9 +14919,9 @@
 // The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions
 // form square matrices. The outputs are two tensors containing the signs and
 // absolute values of the log determinants for all N input submatrices
-// `[..., :, :]` such that the determinant = sign*exp(log_abs_determinant).
-// The log_abs_determinant is computed as det(P)*sum(log(diag(LU))) where LU
-// is the LU decomposition of the input and P is the corresponding
+// `[..., :, :]` such that `determinant = sign*exp(log_abs_determinant)`.
+// The `log_abs_determinant` is computed as `det(P)*sum(log(diag(LU)))` where `LU`
+// is the LU decomposition of the input and `P` is the corresponding
 // permutation matrix.
 //
 // Arguments:
@@ -15307,6 +15215,118 @@
 	return op.Output(0)
 }
 
+// Does nothing. Serves as a control trigger for scheduling.
+//
+// Only useful as a placeholder for control edges.
+//
+// Returns the created operation.
+func ControlTrigger(scope *Scope) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "ControlTrigger",
+	}
+	return scope.AddOperation(opspec)
+}
+
+// Interleave the values from the `data` tensors into a single tensor.
+//
+// Builds a merged tensor such that
+//
+// ```python
+//     merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]
+// ```
+//
+// For example, if each `indices[m]` is a scalar or vector, we have
+//
+// ```python
+//     # Scalar indices:
+//     merged[indices[m], ...] = data[m][...]
+//
+//     # Vector indices:
+//     merged[indices[m][i], ...] = data[m][i, ...]
+// ```
+//
+// Each `data[i].shape` must start with the corresponding `indices[i].shape`,
+// and the rest of `data[i].shape` must be constant w.r.t. `i`.  That is, we
+// must have `data[i].shape = indices[i].shape + constant`.  In terms of this
+// `constant`, the output shape is
+//
+//     merged.shape = [max(indices)] + constant
+//
+// Values may be merged in parallel, so if an index appears in both `indices[m][i]`
+// and `indices[n][j]`, the result may be invalid. This differs from the normal
+// DynamicStitch operator, which defines the behavior in that case.
+//
+// For example:
+//
+// ```python
+//     indices[0] = 6
+//     indices[1] = [4, 1]
+//     indices[2] = [[5, 2], [0, 3]]
+//     data[0] = [61, 62]
+//     data[1] = [[41, 42], [11, 12]]
+//     data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]
+//     merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],
+//               [51, 52], [61, 62]]
+// ```
+//
+// This method can be used to merge partitions created by `dynamic_partition`
+// as illustrated in the following example:
+//
+// ```python
+//     # Apply a function (increment x_i) to elements for which a certain
+//     # condition applies (x_i != -1 in this example).
+//     x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])
+//     condition_mask=tf.not_equal(x,tf.constant(-1.))
+//     partitioned_data = tf.dynamic_partition(
+//         x, tf.cast(condition_mask, tf.int32) , 2)
+//     partitioned_data[1] = partitioned_data[1] + 1.0
+//     condition_indices = tf.dynamic_partition(
+//         tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)
+//     x = tf.dynamic_stitch(condition_indices, partitioned_data)
+//     # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain
+//     # unchanged.
+// ```
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/DynamicStitch.png" alt>
+// </div>
+func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "ParallelDynamicStitch",
+		Input: []tf.Input{
+			tf.OutputList(indices), tf.OutputList(data),
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
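As a point of reference, here is a minimal sketch of driving this wrapper from the Go `op` package; the values are illustrative, and the `NewScope`/`Const`/`Session` plumbing is the standard Go graph-building API:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// Scalar indices: data[0] becomes row 0 of the result, data[1] row 1.
	indices := []tf.Output{
		op.Const(s.SubScope("i0"), int32(0)),
		op.Const(s.SubScope("i1"), int32(1)),
	}
	data := []tf.Output{
		op.Const(s.SubScope("d0"), []float32{10, 11}),
		op.Const(s.SubScope("d1"), []float32{20, 21}),
	}
	merged := op.ParallelDynamicStitch(s, indices, data)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{merged}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [[10 11] [20 21]]
}
```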
+
+// Returns a Tensor stacking all keys in a tensor map.
+//
+// input_handle: the input map
+// keys: the returned Tensor of all keys in the map
+func TensorMapStackKeys(scope *Scope, input_handle tf.Output, key_dtype tf.DataType) (keys tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"key_dtype": key_dtype}
+	opspec := tf.OpSpec{
+		Type: "TensorMapStackKeys",
+		Input: []tf.Input{
+			input_handle,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
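To show where `TensorMapStackKeys` fits during graph construction, a short sketch follows; `EmptyTensorMap` and `TensorMapInsert` are assumed to be generated into the same package, and should be treated as hypothetical if a given build does not provide them:

```go
s := op.NewScope()
// Hypothetical companions (not part of this diff): EmptyTensorMap creates an
// empty map handle; TensorMapInsert returns the map with the key/value added.
m := op.EmptyTensorMap(s)
k := op.Const(s.SubScope("k"), int64(7))
v := op.Const(s.SubScope("v"), []float32{1, 2, 3})
m = op.TensorMapInsert(s, m, k, v)
// Stack every key in the map into a single int64 tensor.
keys := op.TensorMapStackKeys(s, m, tf.Int64)
_ = keys
```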
+
 // Returns whether the given key exists in the map.
 //
 // input_handle: the input map
@@ -15326,27 +15346,6 @@
 	return op.Output(0)
 }
 
-// Returns the value from a given key in a tensor map.
-//
-// input_handle: the input map
-// key: the key to be looked up
-// value: the value found from the given key
-func TensorMapLookup(scope *Scope, input_handle tf.Output, key tf.Output, value_dtype tf.DataType) (value tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"value_dtype": value_dtype}
-	opspec := tf.OpSpec{
-		Type: "TensorMapLookup",
-		Input: []tf.Input{
-			input_handle, key,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Inverse 3D fast Fourier transform.
 //
 // Computes the inverse 3-dimensional discrete Fourier transform over the
@@ -15395,30 +15394,22 @@
 	return op.Output(0)
 }
 
-// Merges summaries.
+// Returns the value from a given key in a tensor map.
 //
-// This op creates a
-// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
-// protocol buffer that contains the union of all the values in the input
-// summaries.
-//
-// When the Op is run, it reports an `InvalidArgument` error if multiple values
-// in the summaries to merge use the same tag.
-//
-// Arguments:
-//	inputs: Can be of any shape.  Each must contain serialized `Summary` protocol
-// buffers.
-//
-// Returns Scalar. Serialized `Summary` protocol buffer.
-func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) {
+// input_handle: the input map
+// key: the key to be looked up
+// value: the value found from the given key
+func TensorMapLookup(scope *Scope, input_handle tf.Output, key tf.Output, value_dtype tf.DataType) (value tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
+	attrs := map[string]interface{}{"value_dtype": value_dtype}
 	opspec := tf.OpSpec{
-		Type: "MergeSummary",
+		Type: "TensorMapLookup",
 		Input: []tf.Input{
-			tf.OutputList(inputs),
+			input_handle, key,
 		},
+		Attrs: attrs,
 	}
 	op := scope.AddOperation(opspec)
 	return op.Output(0)
@@ -20366,8 +20357,7 @@
 // input_handle: the original map
 // output_handle: the map with value from given key removed
 // key: the key of the value to be erased
-// value: the value that was erased
-func TensorMapErase(scope *Scope, input_handle tf.Output, key tf.Output, value_dtype tf.DataType) (output_handle tf.Output, value tf.Output) {
+func TensorMapErase(scope *Scope, input_handle tf.Output, key tf.Output, value_dtype tf.DataType) (output_handle tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
@@ -20380,7 +20370,7 @@
 		Attrs: attrs,
 	}
 	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1)
+	return op.Output(0)
 }
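With `TensorMapErase` now returning only the updated map handle (per the change above), callers that previously consumed the erased value need a preceding `TensorMapLookup`. A sketch, reusing the hypothetical map setup from the earlier example:

```go
// Read the value first; erase no longer returns it.
v := op.TensorMapLookup(s, m, k, tf.Float)
m = op.TensorMapErase(s, m, k, tf.Float)
_, _ = v, m
```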
 
 // Shuffle dimensions of x according to a permutation.
@@ -21542,6 +21532,11 @@
 //
 // *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
 // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+//
+// Given two input tensors, the `tf.add` operation computes the element-wise sum.
+//
+// Both input and output have a range `(-inf, inf)`.
+//
 func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
 	if scope.Err() != nil {
 		return
@@ -21619,6 +21614,12 @@
 }
 
 // Computes acos of x element-wise.
+//
+// Given an input tensor, the `tf.math.acos` operation returns the inverse
+// cosine of each element. If `y = tf.math.cos(x)`, then `x = tf.math.acos(y)`.
+//
+// The input range is `[-1, 1]` and the output range is `[0, pi]`.
+//
 func Acos(scope *Scope, x tf.Output) (y tf.Output) {
 	if scope.Err() != nil {
 		return
@@ -24329,9 +24330,28 @@
 	}
 }
 
-// Returns the permuted vector/tensor in the destination data format given the
+// Permutes the input tensor from `src_format` to `dst_format`.
 //
-// one in the source data format.
+// The input tensor must be a vector of size 4, or a 4x2 tensor.
+//
+// For example, with `src_format` of `NHWC`, `dst_format` of `NCHW`, and inputs:
+// ```
+// [1, 2, 3, 4]
+// ```
+// and
+// ```
+// [[1, 2, 3, 4],
+//  [5, 6, 7, 8]]
+// ```
+// the outputs will be (respectively):
+// ```
+// [1, 4, 2, 3]
+// ```
+// and
+// ```
+// [[1, 4, 2, 3],
+//  [5, 8, 6, 7]]
+// ```
 //
 // Arguments:
 //	x: Vector of size 4 or Tensor of shape (4, 2) in source data format.
@@ -42481,7 +42501,7 @@
 //	mom: Should be from a Variable().
 //	lr: Scaling factor. Must be a scalar.
 //	rho: Decay rate. Must be a scalar.
-//
+//	momentum: Momentum scale. Must be a scalar.
 //	epsilon: Ridge term. Must be a scalar.
 //	grad: The gradient.
 //
@@ -42768,6 +42788,35 @@
 	return op.Output(0), op.Output(1)
 }
 
+// Merges summaries.
+//
+// This op creates a
+// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
+// protocol buffer that contains the union of all the values in the input
+// summaries.
+//
+// When the Op is run, it reports an `InvalidArgument` error if multiple values
+// in the summaries to merge use the same tag.
+//
+// Arguments:
+//	inputs: Can be of any shape.  Each must contain serialized `Summary` protocol
+// buffers.
+//
+// Returns Scalar. Serialized `Summary` protocol buffer.
+func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "MergeSummary",
+		Input: []tf.Input{
+			tf.OutputList(inputs),
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // MaxPool3DGradAttr is an optional argument to MaxPool3DGrad.
 type MaxPool3DGradAttr func(optionalAttr)
 
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index 9221d35..d9036ce 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -83,7 +83,7 @@
 		return nil, err
 	}
 	nflattened := numElements(shape)
-	nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
+	nbytes := TypeOf(dataType, nil).Size() * uintptr(nflattened)
 	if dataType == String {
 		nbytes = uintptr(nflattened) * C.sizeof_TF_TString
 	}
@@ -168,7 +168,7 @@
 	if err := isTensorSerializable(dataType); err != nil {
 		return nil, err
 	}
-	nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
+	nbytes := TypeOf(dataType, nil).Size() * uintptr(numElements(shape))
 	var shapePtr *C.int64_t
 	if len(shape) > 0 {
 		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
@@ -207,6 +207,28 @@
 // Shape returns the shape of the Tensor.
 func (t *Tensor) Shape() []int64 { return t.shape }
 
+// Reshape updates the tensor's shape in place when possible, or returns an
+// error otherwise.
+func (t *Tensor) Reshape(newShape []int64) error {
+	oldShapeSize := numElements(t.shape)
+	newShapeSize := numElements(newShape)
+
+	if oldShapeSize != newShapeSize {
+		return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, oldShapeSize, newShape, newShapeSize)
+	}
+
+	if len(newShape) == 0 {
+		return nil
+	}
+
+	shapePtr := (*C.int64_t)(unsafe.Pointer(&newShape[0]))
+
+	status := newStatus()
+	C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(newShape)), status.c)
+
+	if err := status.Err(); err != nil {
+		return err
+	}
+	// Keep the Go-side cached shape in sync with the underlying TF tensor.
+	t.shape = newShape
+	return nil
+}
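A minimal usage sketch from a client package (aliased `tf` here), also exercising the newly exported `TypeOf`:

```go
t, err := tf.NewTensor([][]float32{{1, 2}, {3, 4}})
if err != nil {
	panic(err)
}
// 2x2 -> 4 preserves the element count, so the reshape succeeds in place.
if err := t.Reshape([]int64{4}); err != nil {
	panic(err)
}
fmt.Println(t.Shape())                          // [4]
fmt.Println(tf.TypeOf(tf.Float, []int64{2, 3})) // [][]float32
```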
+
 // Value converts the Tensor to a Go value. For now, not all Tensor types are
 // supported, and this function may panic if it encounters an unsupported
 // DataType.
@@ -407,8 +429,8 @@
 	panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
 }
 
-// typeOf converts from a DataType and Shape to the equivalent Go type.
-func typeOf(dt DataType, shape []int64) reflect.Type {
+// TypeOf converts from a DataType and Shape to the equivalent Go type.
+func TypeOf(dt DataType, shape []int64) reflect.Type {
 	ret := typeForDataType(dt)
 	for range shape {
 		ret = reflect.SliceOf(ret)
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index e80e32f..058d2e7 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -1,6 +1,7 @@
 load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test")
 load("//tensorflow/lite:build_def.bzl", "if_tflite_experimental_runtime", "tflite_cc_shared_object", "tflite_copts", "tflite_experimental_runtime_linkopts")
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -16,13 +17,6 @@
 ]))
 
 config_setting(
-    name = "enable_default_profiler",
-    values = {
-        "copt": "-DTFLITE_ENABLE_DEFAULT_PROFILER",
-    },
-)
-
-config_setting(
     name = "gemmlowp_profiling",
     values = {
         "copt": "-DGEMMLOWP_PROFILING",
@@ -90,6 +84,7 @@
 cc_library(
     name = "version",
     hdrs = ["version.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
     # Note that we only use the header defines from :version_lib.
     deps = ["//tensorflow/core:version_lib"],
@@ -107,6 +102,7 @@
     name = "arena_planner",
     srcs = ["arena_planner.cc"],
     hdrs = ["arena_planner.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
     deps = [
         ":graph_info",
@@ -145,6 +141,7 @@
     name = "external_cpu_backend_context",
     srcs = ["external_cpu_backend_context.cc"],
     hdrs = ["external_cpu_backend_context.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
     deps = [
         "//tensorflow/lite/c:common",
@@ -154,6 +151,7 @@
 cc_library(
     name = "graph_info",
     hdrs = ["graph_info.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
     deps = ["//tensorflow/lite/c:common"],
 )
@@ -161,6 +159,7 @@
 cc_library(
     name = "memory_planner",
     hdrs = ["memory_planner.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
     deps = ["//tensorflow/lite/c:common"],
 )
@@ -169,6 +168,7 @@
     name = "simple_memory_arena",
     srcs = ["simple_memory_arena.cc"],
     hdrs = ["simple_memory_arena.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
     deps = ["//tensorflow/lite/c:common"],
 )
@@ -188,6 +188,7 @@
         "builtin_ops.h",
         "context_util.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     deps = ["//tensorflow/lite/c:common"],
 )
 
@@ -198,6 +199,7 @@
     hdrs = [
         "string_type.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
 )
 
@@ -219,6 +221,7 @@
     hdrs = [
         "allocation.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
     deps = [
         ":string",
@@ -240,6 +243,7 @@
         "stderr_reporter.cc",
     ],
     hdrs = FRAMEWORK_LIB_HDRS,
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
     visibility = [
         "//tensorflow/lite:__subpackages__",
@@ -264,13 +268,9 @@
         "//tensorflow/lite/experimental/resource",
         "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/nnapi:nnapi_implementation",
+        "//tensorflow/lite/profiling:platform_profiler",
         "//tensorflow/lite/schema:schema_fbs",
-    ] + select({
-        ":enable_default_profiler": [
-            "//tensorflow/lite/profiling:platform_profiler",
-        ],
-        "//conditions:default": [],
-    }),
+    ],
     alwayslink = 1,
 )
 
@@ -280,6 +280,7 @@
     srcs = [
     ],
     hdrs = FRAMEWORK_LIB_HDRS,
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
     defines = if_tflite_experimental_runtime(
         if_eager = ["TFLITE_EXPERIMENTAL_RUNTIME_EAGER"],
@@ -312,6 +313,7 @@
     name = "string_util",
     srcs = ["string_util.cc"],
     hdrs = ["string_util.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS,
     deps = [
         ":string",
@@ -356,6 +358,7 @@
 
 cc_library(
     name = "tflite_with_xnnpack_default",
+    compatible_with = get_compatible_with_portable(),
     visibility = ["//visibility:private"],
     # TODO(b/151246885): put ":tflite_with_xnnpack_enabled" to macos/windows
     # once we have a good testing coverage on these two platforms.
@@ -373,6 +376,7 @@
         "core/macros.h",
         "tflite_with_xnnpack_optional.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
     deps = [
         "//tensorflow/lite/c:common",
@@ -568,6 +572,7 @@
     name = "util",
     srcs = ["util.cc"],
     hdrs = ["util.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS + tflite_copts(),
     deps = [
         ":kernel_api",
@@ -611,6 +616,7 @@
         ],
     }),
     hdrs = ["minimal_logging.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = TFLITE_DEFAULT_COPTS + tflite_copts(),
     linkopts = select({
         "//tensorflow:android": ["-llog"],
@@ -631,6 +637,7 @@
             "type_to_tflitetype.h",
         ],
     }),
+    compatible_with = get_compatible_with_portable(),
     deps = ["//tensorflow/lite/c:common"],
 )
 
@@ -660,6 +667,7 @@
 cc_library(
     name = "shared_library",
     hdrs = ["shared_library.h"],
+    compatible_with = get_compatible_with_portable(),
     linkopts = if_not_windows(["-ldl"]),
 )
 
diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt
new file mode 100644
index 0000000..cfd8ebf
--- /dev/null
+++ b/tensorflow/lite/CMakeLists.txt
@@ -0,0 +1,341 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Builds the TensorFlow Lite runtime.
+#
+# WARNING: This build is experimental and subject to change.
+# This has only been tested on Windows, Linux and macOS.
+#
+# The following are not currently supported:
+# - GPU acceleration
+# - Android
+# - iOS
+# - Micro backend
+# - Tests
+# - Many features in experimental
+# - Host tools (e.g. conversion / analysis tools)
+
+cmake_minimum_required(VERSION 3.16)
+# Double colon in target name means ALIAS or IMPORTED target.
+cmake_policy(SET CMP0028 NEW)
+# Enable MACOSX_RPATH (@rpath) for built dynamic libraries.
+cmake_policy(SET CMP0042 NEW)
+project(tensorflow-lite C CXX)
+set(TENSORFLOW_SOURCE_DIR "" CACHE PATH
+  "Directory that contains the TensorFlow project"
+)
+if(NOT TENSORFLOW_SOURCE_DIR)
+  set(TENSORFLOW_SOURCE_DIR "${CMAKE_SOURCE_DIR}/../../")
+endif()
+set(TF_SOURCE_DIR "${TENSORFLOW_SOURCE_DIR}/tensorflow")
+set(TFLITE_SOURCE_DIR "${CMAKE_SOURCE_DIR}")
+set(CMAKE_MODULE_PATH "${TFLITE_SOURCE_DIR}/tools/cmake/modules" ${CMAKE_MODULE_PATH})
+set(CMAKE_PREFIX_PATH "${TFLITE_SOURCE_DIR}/tools/cmake/modules" ${CMAKE_PREFIX_PATH})
+
+option(TFLITE_ENABLE_RUY "Enable experimental RUY integration" OFF)
+option(TFLITE_ENABLE_RESOURCE "Enable experimental support for resources" ON)
+option(TFLITE_ENABLE_NNAPI "Enable NNAPI (Android only)." ON)
+option(TFLITE_ENABLE_MMAP "Enable MMAP (unsupported on Windows)" ON)
+option(TFLITE_ENABLE_GPU "Enable GPU (not supported)" OFF)
+# This must be enabled when converting from TF models with SELECT_TF_OPS
+# enabled.
+# https://www.tensorflow.org/lite/guide/ops_select#converting_the_model
+# This is currently not supported.
+option(TFLITE_ENABLE_FLEX "Enable SELECT_TF_OPS" OFF) # TODO: Add support
+option(TFLITE_ENABLE_XNNPACK "Enable XNNPACK backend" OFF) # TODO: Add XNNPACK
+option(TFLITE_ENABLE_PROFILING "Enable profiling" OFF)
+set(CMAKE_CXX_STANDARD 14)  # Some components require C++14.
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(_TFLITE_ENABLE_NNAPI "${TFLITE_ENABLE_NNAPI}")
+if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
+  set(_TFLITE_ENABLE_NNAPI OFF)
+endif()
+set(_TFLITE_ENABLE_MMAP "${TFLITE_ENABLE_MMAP}")
+if(${CMAKE_SYSTEM_NAME} MATCHES "Windows")
+  # See https://github.com/tensorflow/tensorflow/blob/\
+  # 2b96f3662bd776e277f86997659e61046b56c315/tensorflow/lite/tools/make/\
+  # Makefile#L157
+  set(_TFLITE_ENABLE_MMAP OFF)
+endif()
+# Simplifies inclusion of non-test sources and headers from a directory.
+# SOURCE_DIR: Directory to search for files.
+# SOURCES_VAR: Variable to append with all matching *.cc and *.h files.
+# [FILTER expression0 .. expressionN]:
+#   Additional regular expressions to filter the set of matching
+#   files. By default, all files ending in "(_test|test_util)\\.(c|cc|h)" are
+#   removed.
+# [RECURSE]: Whether to recursively search SOURCE_DIR.
+macro(populate_source_vars SOURCE_DIR SOURCES_VAR)
+  cmake_parse_arguments(ARGS "RECURSE" "" "FILTER" ${ARGN})
+  if(ARGS_RECURSE)
+    set(GLOB_OP GLOB_RECURSE)
+  else()
+    set(GLOB_OP GLOB)
+  endif()
+  set(DEFAULT_FILE_FILTER ".*(_test|test_util)\\.(c|cc|h)$")
+  file(${GLOB_OP} FOUND_SOURCES "${SOURCE_DIR}/*.*")
+  list(FILTER FOUND_SOURCES INCLUDE REGEX ".*\\.(c|cc|h)$")
+  list(FILTER FOUND_SOURCES EXCLUDE REGEX "${DEFAULT_FILE_FILTER}")
+  foreach(FILE_FILTER ${ARGS_FILTER})
+    list(FILTER FOUND_SOURCES EXCLUDE REGEX "${FILE_FILTER}")
+  endforeach()
+  list(APPEND ${SOURCES_VAR} ${FOUND_SOURCES})
+endmacro()
+# Simplifies inclusion of non-test sources and headers from a directory
+# relative to TFLITE_SOURCE_DIR. See populate_source_vars() for the
+# description of arguments including and following SOURCES_VAR.
+macro(populate_tflite_source_vars RELATIVE_DIR SOURCES_VAR)
+  populate_source_vars(
+    "${TFLITE_SOURCE_DIR}/${RELATIVE_DIR}" ${SOURCES_VAR} ${ARGN}
+  )
+endmacro()
+# Simplifies inclusion of non-test sources and headers from a directory
+# relative to TF_SOURCE_DIR. See populate_source_vars() for the description of
+# arguments including and following SOURCES_VAR.
+macro(populate_tf_source_vars RELATIVE_DIR SOURCES_VAR)
+  populate_source_vars(
+    "${TF_SOURCE_DIR}/${RELATIVE_DIR}" ${SOURCES_VAR} ${ARGN}
+  )
+endmacro()
+# Find TensorFlow Lite dependencies.
+find_package(absl REQUIRED CONFIG)
+find_package(eigen REQUIRED)
+find_package(farmhash REQUIRED)
+find_package(fft2d REQUIRED)
+find_package(flatbuffers REQUIRED)
+find_package(gemmlowp REQUIRED)
+find_package(neon2sse REQUIRED)
+find_package(ruy REQUIRED)
+# Generate TensorFlow Lite FlatBuffer code.
+# This is not currently necessary since the generated code is checked into
+# the repository, but it would likely be preferable to do this in the future.
+# NOTE: This will not work for cross compilation (e.g. for iOS, Android, etc.)
+# as flatc needs to be compiled with the host toolchain and this currently
+# builds with the target toolchain. Instead this should recursively call
+# cmake with the default host toolchain to build flatc.
+set(TFLITE_FLATBUFFERS_SCHEMAS "${TFLITE_SOURCE_DIR}/schema/schema.fbs")
+set(TFLITE_FLATBUFFERS_GEN_DIR
+  "${CMAKE_BINARY_DIR}/flatbuffers_generated/"
+)
+set(TFLITE_FLATBUFFERS_HDRS "")
+foreach(INPUT_SCHEMA ${TFLITE_FLATBUFFERS_SCHEMAS})
+  file(RELATIVE_PATH FILENAME "${TENSORFLOW_SOURCE_DIR}" "${INPUT_SCHEMA}")
+  get_filename_component(OUTPUT_DIR
+    "${TFLITE_FLATBUFFERS_GEN_DIR}/${FILENAME}" DIRECTORY
+  )
+  get_filename_component(OUTPUT_BASENAME
+    "${FILENAME}" NAME_WE
+  )
+  set(OUTPUT_FILENAME "${OUTPUT_DIR}/${OUTPUT_BASENAME}_generated.h")
+  list(APPEND TFLITE_FLATBUFFERS_HDRS "${OUTPUT_FILENAME}")
+  add_custom_command(
+    OUTPUT "${OUTPUT_FILENAME}"
+    COMMAND flatc
+      --cpp
+      --gen-mutable
+      --gen-object-api
+      --reflect-names
+      -I "${TENSORFLOW_SOURCE_DIR}"
+      -o "${OUTPUT_DIR}"
+      "${INPUT_SCHEMA}"
+    DEPENDS
+      "${INPUT_SCHEMA}")
+endforeach()
+set(TF_TARGET_PRIVATE_OPTIONS "")
+if(CMAKE_CXX_COMPILER_ID MATCHES "Clang$")
+  # TensorFlow uses a heap of deprecated proto fields, so suppress these
+  # warnings until they're fixed.
+  list(APPEND TF_TARGET_PRIVATE_OPTIONS "-Wno-deprecated-declarations")
+endif()
+# Additional compiler flags used when compiling TF Lite.
+set(TFLITE_TARGET_PUBLIC_OPTIONS "")
+set(TFLITE_TARGET_PRIVATE_OPTIONS "")
+# Additional library dependencies based upon enabled features.
+set(TFLITE_TARGET_DEPENDENCIES "")
+if(CMAKE_CXX_COMPILER_ID MATCHES "Clang$")
+  # TFLite uses deprecated methods in neon2sse, which generates a huge number
+  # of warnings, so suppress these until they're fixed.
+  list(APPEND TFLITE_TARGET_PRIVATE_OPTIONS "-Wno-deprecated-declarations")
+endif()
+if(CMAKE_SYSTEM_NAME MATCHES "Windows")
+  # Use NOMINMAX to disable the min / max macros in windows.h, as they break
+  # the use of std::min / std::max.
+  # Use NOGDI to disable the ERROR macro, which breaks TensorFlow logging.
+  list(APPEND TFLITE_TARGET_PRIVATE_OPTIONS "-DNOMINMAX" "-DNOGDI")
+endif()
+# Build a list of source files to compile into the TF Lite library.
+populate_tflite_source_vars("." TFLITE_SRCS)
+if(_TFLITE_ENABLE_MMAP)
+  list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation_disabled\\.cc$")
+else()
+  list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation\\.cc$")
+endif()
+if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
+  list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*minimal_logging_android\\.cc$")
+endif()
+if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "iOS")
+  list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*minimal_logging_ios\\.cc$")
+endif()
+populate_tflite_source_vars("core" TFLITE_CORE_SRCS)
+populate_tflite_source_vars("core/api" TFLITE_CORE_API_SRCS)
+populate_tflite_source_vars("c" TFLITE_C_SRCS)
+populate_tflite_source_vars("delegates" TFLITE_DELEGATES_SRCS)
+if(TFLITE_ENABLE_FLEX)
+  message(FATAL_ERROR "TF Lite Flex delegate is currently not supported.")
+  populate_tflite_source_vars("delegates/flex" TFLITE_DELEGATES_FLEX_SRCS)
+  list(APPEND TFLITE_TARGET_DEPENDENCIES
+    absl::inlined_vector
+    absl::optional
+    absl::type_traits
+  )
+endif()
+if(TFLITE_ENABLE_GPU)
+  # Implementation is under delegates/gpu.
+  message(FATAL_ERROR
+    "GPU acceleration is not currently supported in CMake builds"
+  )
+endif()
+if(_TFLITE_ENABLE_NNAPI)
+  populate_tflite_source_vars("delegates/nnapi"
+    TFLITE_DELEGATES_NNAPI_SRCS
+    FILTER "(_test_list|_disabled)\\.(cc|h)$"
+  )
+  populate_tflite_source_vars(
+    "nnapi" TFLITE_NNAPI_SRCS FILTER "(_disabled)\\.(cc|h)$"
+  )
+else()
+  set(TFLITE_DELEGATES_NNAPI_SRCS
+    "${TFLITE_SOURCE_DIR}/delegates/nnapi/nnapi_delegate_disabled.cc"
+  )
+  set(TFLITE_NNAPI_SRCS
+    "${TFLITE_SOURCE_DIR}/nnapi/nnapi_implementation_disabled.cc"
+  )
+endif()
+if(TFLITE_ENABLE_XNNPACK)
+  populate_tflite_source_vars("delegates/xnnpack"
+    TFLITE_DELEGATES_XNNPACK_SRCS
+  )
+endif()
+if(TFLITE_ENABLE_RESOURCE)
+  populate_tflite_source_vars("experimental/resource"
+    TFLITE_EXPERIMENTAL_RESOURCE_SRCS
+  )
+endif()
+populate_tflite_source_vars("experimental/ruy"
+  TFLITE_EXPERIMENTAL_RUY_SRCS
+  FILTER
+  ".*(test(_fast|_slow|_special_specs))\\.(cc|h)$"
+  ".*(benchmark|tune_tool|example)\\.(cc|h)$"
+)
+populate_tflite_source_vars("experimental/ruy/profiler"
+  TFLITE_EXPERIMENTAL_RUY_PROFILER_SRCS
+  FILTER ".*(test|test_instrumented_library)\\.(cc|h)$"
+)
+if(TFLITE_ENABLE_RUY)
+  list(APPEND TFLITE_TARGET_PUBLIC_OPTIONS "-DTFLITE_WITH_RUY")
+endif()
+populate_tflite_source_vars("kernels"
+  TFLITE_KERNEL_SRCS
+  FILTER ".*(_test_util_internal|test_main)\\.(cc|h)"
+)
+populate_tflite_source_vars("kernels/internal" TFLITE_KERNEL_INTERNAL_SRCS)
+populate_tflite_source_vars("kernels/internal/optimized"
+  TFLITE_KERNEL_INTERNAL_OPT_SRCS
+)
+populate_tflite_source_vars("kernels/internal/optimized/integer_ops"
+  TFLITE_KERNEL_INTERNAL_OPT_INTEGER_OPS_SRCS
+)
+populate_tflite_source_vars("kernels/internal/optimized/sparse_ops"
+  TFLITE_KERNEL_INTERNAL_OPT_SPARSE_OPS_SRCS
+)
+populate_tflite_source_vars("kernels/internal/reference"
+  TFLITE_KERNEL_INTERNAL_REF_SRCS
+)
+populate_tflite_source_vars("kernels/internal/reference/integer_ops"
+  TFLITE_KERNEL_INTERNAL_REF_INTEGER_OPS_SRCS
+)
+populate_tflite_source_vars("kernels/internal/reference/sparse_ops"
+  TFLITE_KERNEL_INTERNAL_REF_SPARSE_OPS_SRCS
+)
+if(TFLITE_ENABLE_PROFILING)
+  populate_tflite_source_vars("profiling" TFLITE_KERNEL_PROFILING_SRCS)
+endif()
+populate_tflite_source_vars("tools/optimize" TFLITE_TOOLS_OPTIMIZE_SRCS)
+populate_tflite_source_vars("tools/optimize/calibration"
+  TFLITE_TOOLS_OPTIMIZE_CALIBRATION_SRCS
+)
+populate_tflite_source_vars("tools/optimize/calibration/builtin_logging_ops"
+  TFLITE_TOOLS_OPTIMIZE_CALIBRATION_OPS_SRCS
+)
+populate_tflite_source_vars("tools/optimize/sparsity"
+  TFLITE_TOOLS_OPTIMIZE_SPARSITY_SRCS
+)
+add_library(tensorflowlite
+  ${TFLITE_CORE_API_SRCS}
+  ${TFLITE_CORE_SRCS}
+  ${TFLITE_C_SRCS}
+  ${TFLITE_DELEGATES_FLEX_SRCS}
+  ${TFLITE_DELEGATES_NNAPI_SRCS}
+  ${TFLITE_DELEGATES_SRCS}
+  ${TFLITE_DELEGATES_XNNPACK_SRCS}
+  ${TFLITE_EXPERIMENTAL_RESOURCE_SRCS}
+  ${TFLITE_EXPERIMENTAL_RUY_PROFILER_SRCS}
+  ${TFLITE_EXPERIMENTAL_RUY_SRCS}
+  ${TFLITE_FLATBUFFERS_HDRS}
+  ${TFLITE_KERNEL_INTERNAL_OPT_INTEGER_OPS_SRCS}
+  ${TFLITE_KERNEL_INTERNAL_OPT_SPARSE_OPS_SRCS}
+  ${TFLITE_KERNEL_INTERNAL_OPT_SRCS}
+  ${TFLITE_KERNEL_INTERNAL_REF_INTEGER_OPS_SRCS}
+  ${TFLITE_KERNEL_INTERNAL_REF_SPARSE_OPS_SRCS}
+  ${TFLITE_KERNEL_INTERNAL_REF_SRCS}
+  ${TFLITE_KERNEL_INTERNAL_SRCS}
+  ${TFLITE_KERNEL_PROFILING_SRCS}
+  ${TFLITE_KERNEL_SRCS}
+  ${TFLITE_NNAPI_SRCS}
+  ${TFLITE_SRCS}
+  ${TFLITE_TOOLS_OPTIMIZE_CALIBRATION_OPS_SRCS}
+  ${TFLITE_TOOLS_OPTIMIZE_CALIBRATION_SRCS}
+  ${TFLITE_TOOLS_OPTIMIZE_SPARSITY_SRCS}
+  ${TFLITE_TOOLS_OPTIMIZE_SRCS}
+)
+target_link_libraries(tensorflowlite
+  PUBLIC
+    Eigen3::Eigen
+    NEON_2_SSE
+    absl::flags
+    absl::hash
+    absl::status
+    absl::strings
+    absl::synchronization
+    absl::variant
+    farmhash
+    fft2d_fftsg2d
+    flatbuffers
+    gemmlowp
+    ruy
+    ${TFLITE_TARGET_DEPENDENCIES}
+)
+target_include_directories(tensorflowlite
+  PUBLIC
+    "${TENSORFLOW_SOURCE_DIR}"
+  PRIVATE
+    "${TFLITE_FLATBUFFERS_GEN_DIR}"
+)
+target_compile_options(tensorflowlite
+  PUBLIC ${TFLITE_TARGET_PUBLIC_OPTIONS}
+  PRIVATE ${TFLITE_TARGET_PRIVATE_OPTIONS}
+)
+add_library(tensorflow::tensorflowlite ALIAS tensorflowlite)
diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD
index 5ac6d78..6662ca2 100644
--- a/tensorflow/lite/c/BUILD
+++ b/tensorflow/lite/c/BUILD
@@ -3,6 +3,7 @@
     "tflite_cc_shared_object",
     "tflite_copts",
 )
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -123,6 +124,7 @@
         "builtin_op_data.h",
         "common.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     alwayslink = 1,
 )
 
diff --git a/tensorflow/lite/c/c_api.h b/tensorflow/lite/c/c_api.h
index 880b80e..152bcf9 100644
--- a/tensorflow/lite/c/c_api.h
+++ b/tensorflow/lite/c/c_api.h
@@ -188,7 +188,7 @@
     const TfLiteInterpreter* interpreter);
 
 // Returns the tensor associated with the output index.
-// REQUIRES: 0 <= input_index < TfLiteInterpreterGetOutputTensorCount(tensor)
+// REQUIRES: 0 <= output_index < TfLiteInterpreterGetOutputTensorCount(tensor)
 //
 // NOTE: The shape and underlying data buffer for output tensors may not be
 // available until after the output tensor has been both sized and allocated.
diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h
index d320a90..31405df 100644
--- a/tensorflow/lite/c/common.h
+++ b/tensorflow/lite/c/common.h
@@ -226,6 +226,17 @@
     }                                                                      \
   } while (0)
 
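+// Ensures the values `a` and `b` differ by at most `epsilon`; otherwise logs
+// the offending values and returns kTfLiteError from the enclosing function,
+// e.g. TF_LITE_ENSURE_NEAR(context, output_scale, 1.0, 1e-5).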
+#define TF_LITE_ENSURE_NEAR(context, a, b, epsilon)                          \
+  do {                                                                       \
+    auto delta = ((a) > (b)) ? ((a) - (b)) : ((b) - (a));                    \
+    if (delta > (epsilon)) {                                                 \
+      TF_LITE_KERNEL_LOG((context), "%s:%d %s not near %s (%f != %f)",       \
+                         __FILE__, __LINE__, #a, #b, static_cast<double>(a), \
+                         static_cast<double>(b));                            \
+      return kTfLiteError;                                                   \
+    }                                                                        \
+  } while (0)
+
 #define TF_LITE_ENSURE_OK(context, status) \
   do {                                     \
     const TfLiteStatus s = (status);       \
diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD
index a1e6fc4..4511171 100644
--- a/tensorflow/lite/core/api/BUILD
+++ b/tensorflow/lite/core/api/BUILD
@@ -1,5 +1,6 @@
 load("//tensorflow/lite:build_def.bzl", "tflite_copts")
 load("//tensorflow/lite/micro:build_def.bzl", "micro_copts")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -21,6 +22,7 @@
         "profiler.h",
         "tensor_utils.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + micro_copts(),
     deps = [
         "@flatbuffers//:runtime_cc",
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index ecdb04c..f3e8b8a 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -674,13 +674,17 @@
       continue;
     }
 
-    // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be
-    // allocated after the initial `PrepareOpsAndTensors()` is called.
-    TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type,
-                      kTfLiteArenaRwPersistent);
-    TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr);
-
-    tflite::ResetVariableTensor(&tensor);
+    if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
+      // If a variable tensor's allocation type is `kTfLiteArenaRwPersistent`,
+      // it must be allocated after the initial `PrepareOpsAndTensors()` is
+      // called.
+      TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr);
+      tflite::ResetVariableTensor(&tensor);
+    } else {
+      // If a variable tensor's allocation type is not `kTfLiteArenaRwPersistent`,
+      // it can only be `kTfLiteCustom`, in which case we do not reset it.
+      TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, kTfLiteCustom);
+    }
   }
   return kTfLiteOk;
 }
@@ -1486,8 +1490,10 @@
 TfLiteStatus Subgraph::SetCustomAllocationForTensor(
     int tensor_index, const TfLiteCustomAllocation& allocation) {
   TfLiteTensor* tensor = &context_.tensors[tensor_index];
-  TF_LITE_ENSURE(context(), tensor->allocation_type == kTfLiteArenaRw ||
-                                tensor->allocation_type == kTfLiteCustom);
+  TF_LITE_ENSURE(context(),
+                 (tensor->allocation_type == kTfLiteArenaRw ||
+                  tensor->allocation_type == kTfLiteArenaRwPersistent ||
+                  tensor->allocation_type == kTfLiteCustom));
   TF_LITE_ENSURE_STATUS(
       ValidateCustomAllocationForTensor(context(), tensor, allocation));
 
diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h
index 3a28b4c..4e40d20 100644
--- a/tensorflow/lite/core/subgraph.h
+++ b/tensorflow/lite/core/subgraph.h
@@ -342,8 +342,8 @@
   // for the tensor, it can no longer be reset to the TFLite arena memory.
   //
   // Parameters should satisfy the following conditions:
-  // 1. tensor->allocation_type == kTfLiteArenaRw
-  //    In general, this is true for all non-constants such as I/O tensors.
+  // 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
+  //    In general, this is true for I/O tensors and variable tensors.
   // 2. allocation->data has the appropriate permissions for runtime access
   //    (Read-only for inputs, Read-Write for others), and outlives Interpreter.
   // 3. allocation->bytes >= tensor->bytes.
diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD
index e1f91f3..3998a3c 100644
--- a/tensorflow/lite/delegates/BUILD
+++ b/tensorflow/lite/delegates/BUILD
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -23,6 +24,7 @@
 cc_library(
     name = "status",
     hdrs = ["status.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     deps = [
         "//tensorflow/lite/c:common",
@@ -33,6 +35,7 @@
     name = "utils",
     srcs = ["utils.cc"],
     hdrs = ["utils.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     deps = [
         "//tensorflow/lite:kernel_api",
diff --git a/tensorflow/lite/delegates/external/README.md b/tensorflow/lite/delegates/external/README.md
index d110ded..0194518 100644
--- a/tensorflow/lite/delegates/external/README.md
+++ b/tensorflow/lite/delegates/external/README.md
@@ -23,7 +23,7 @@
 void tflite_plugin_destroy_delegate(TfLiteDelegate* delegate)
 ```
 
-The external delegate provides an opague and transparent way to utilize a
+The external delegate provides an opaque and transparent way to utilize a
 Tensorflow Lite delegate when performing inference. In other words, one may
 replace the actual Tensorflow Lite delegate by simply updating the dynamic
 library without changing the application code. We developed this mainly for
diff --git a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
index eefbeb7..1168196 100644
--- a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
+++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc
@@ -75,6 +75,7 @@
           "BiasAdd",
           "BiasAddGrad",
           "BiasAddV1",
+          "Bincount",
           "BoostedTreesBucketize",
           "BroadcastArgs",
           "BroadcastGradientArgs",
@@ -116,6 +117,7 @@
           "DecodeWav",
           "DeepCopy",
           "DeleteSessionTensor",
+          "DenseBincount",
           "DepthToSpace",
           "DepthwiseConv2dNative",
           "Dequantize",
@@ -302,6 +304,7 @@
           "RFFT",
           "RFFT2D",
           "RFFT3D",
+          "RaggedBincount",
           "RaggedRange",
           "RaggedTensorToSparse",
           "RaggedTensorToTensor",
@@ -416,6 +419,7 @@
           "SparseApplyProximalAdagrad",
           "SparseApplyProximalGradientDescent",
           "SparseApplyRMSProp",
+          "SparseBincount",
           "SparseCross",
           "SparseCrossHashed",
           "SparseCrossV2",
diff --git a/tensorflow/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc
index 2be9280..8e3ed96 100644
--- a/tensorflow/lite/delegates/flex/delegate_data.cc
+++ b/tensorflow/lite/delegates/flex/delegate_data.cc
@@ -46,7 +46,6 @@
   eager_context_ = new tensorflow::EagerContext(
       session_options,
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-      tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
       /*async=*/false, /*lazy_copy_function_remote_inputs=*/false,
       device_mgr.release(), /*device_mgr_owned*/ true, rendezvous, nullptr);
   return tensorflow::Status();
diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index 9ae3836..a94805e 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -76,7 +76,10 @@
     ],
     deps = [
         ":arguments",
+        ":buffer",
+        ":device_info",
         ":gpu_object",
+        ":tensor",
         ":tensor_type",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.cc b/tensorflow/lite/delegates/gpu/cl/arguments.cc
index 5623de2..b7e6b08 100644
--- a/tensorflow/lite/delegates/gpu/cl/arguments.cc
+++ b/tensorflow/lite/delegates/gpu/cl/arguments.cc
@@ -256,13 +256,6 @@
   object_refs_[name] = {std::move(descriptor_ptr)};
 }
 
-void Arguments::AddObject(const std::string& name, AccessType access_type,
-                          GPUObjectPtr&& object,
-                          GPUObjectDescriptorPtr&& descriptor_ptr) {
-  descriptor_ptr->SetAccess(access_type);
-  objects_[name] = {std::move(object), std::move(descriptor_ptr)};
-}
-
 void Arguments::AddObject(const std::string& name,
                           GPUObjectDescriptorPtr&& descriptor_ptr) {
   descriptor_ptr->SetAccess(AccessType::READ);
diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h b/tensorflow/lite/delegates/gpu/cl/arguments.h
index 643e1b7..4636a06 100644
--- a/tensorflow/lite/delegates/gpu/cl/arguments.h
+++ b/tensorflow/lite/delegates/gpu/cl/arguments.h
@@ -39,37 +39,16 @@
   void AddFloat(const std::string& name, float value = 0.0f);
   void AddHalf(const std::string& name, half value = half(0.0f));
   void AddInt(const std::string& name, int value = 0);
-  void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
-  void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
-  void AddImage2DArray(const std::string& name,
-                       const GPUImage2DArrayDescriptor& desc);
-  void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
-  void AddImageBuffer(const std::string& name,
-                      const GPUImageBufferDescriptor& desc);
-  void AddCustomMemory(const std::string& name,
-                       const GPUCustomMemoryDescriptor& desc);
-
   void AddObjectRef(const std::string& name, AccessType access_type,
                     GPUObjectDescriptorPtr&& descriptor_ptr);
-  void AddObject(const std::string& name, AccessType access_type,
-                 GPUObjectPtr&& object,
-                 GPUObjectDescriptorPtr&& descriptor_ptr);
   void AddObject(const std::string& name,
                  GPUObjectDescriptorPtr&& descriptor_ptr);
 
   absl::Status SetInt(const std::string& name, int value);
   absl::Status SetFloat(const std::string& name, float value);
   absl::Status SetHalf(const std::string& name, half value);
-  absl::Status SetImage2D(const std::string& name, cl_mem memory);
-  absl::Status SetBuffer(const std::string& name, cl_mem memory);
-  absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
-  absl::Status SetImage3D(const std::string& name, cl_mem memory);
-  absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
-  absl::Status SetCustomMemory(const std::string& name, cl_mem memory);
   absl::Status SetObjectRef(const std::string& name, const GPUObject* object);
 
-  std::string GetListOfArgs();
-
   absl::Status Bind(cl_kernel kernel, int offset = 0);
 
   void RenameArgs(const std::string& postfix, std::string* code) const;
@@ -87,6 +66,25 @@
   Arguments& operator=(const Arguments&) = delete;
 
  private:
+  void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc);
+  void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc);
+  void AddImage2DArray(const std::string& name,
+                       const GPUImage2DArrayDescriptor& desc);
+  void AddImage3D(const std::string& name, const GPUImage3DDescriptor& desc);
+  void AddImageBuffer(const std::string& name,
+                      const GPUImageBufferDescriptor& desc);
+  void AddCustomMemory(const std::string& name,
+                       const GPUCustomMemoryDescriptor& desc);
+
+  absl::Status SetImage2D(const std::string& name, cl_mem memory);
+  absl::Status SetBuffer(const std::string& name, cl_mem memory);
+  absl::Status SetImage2DArray(const std::string& name, cl_mem memory);
+  absl::Status SetImage3D(const std::string& name, cl_mem memory);
+  absl::Status SetImageBuffer(const std::string& name, cl_mem memory);
+  absl::Status SetCustomMemory(const std::string& name, cl_mem memory);
+
+  std::string GetListOfArgs();
+
   std::string AddActiveArgument(const std::string& arg_name,
                                 bool use_f32_for_halfs);
   void AddGPUResources(const std::string& name, const GPUResources& resources);
diff --git a/tensorflow/lite/delegates/gpu/cl/arguments_test.cc b/tensorflow/lite/delegates/gpu/cl/arguments_test.cc
index 29a15e1..722ca5b 100644
--- a/tensorflow/lite/delegates/gpu/cl/arguments_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/arguments_test.cc
@@ -14,85 +14,58 @@
 ==============================================================================*/
 #include "tensorflow/lite/delegates/gpu/cl/arguments.h"
 
+#include <cstdint>
 #include <string>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/strings/match.h"
+#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
+#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
 #include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
-#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
-namespace {
-struct TestDescriptor : public GPUObjectDescriptor {
-  absl::Status PerformSelector(const std::string& selector,
-                               const std::vector<std::string>& args,
-                               const std::vector<std::string>& template_args,
-                               std::string* result) const override {
-    if (selector == "Length") {
-      *result = "length";
-      return absl::OkStatus();
-    } else if (selector == "Read") {
-      if (args.size() != 1) {
-        return absl::NotFoundError(
-            absl::StrCat("TestDescriptor Read require one argument, but ",
-                         args.size(), " was passed"));
-      }
-      *result = absl::StrCat("buffer[", args[0], "]");
-      return absl::OkStatus();
-    } else {
-      return absl::NotFoundError(absl::StrCat(
-          "TestDescriptor don't have selector with name - ", selector));
-    }
-  }
-
-  GPUResources GetGPUResources(AccessType access_type) const override {
-    GPUResources resources;
-    resources.ints.push_back("length");
-    GPUBufferDescriptor desc;
-    desc.data_type = DataType::FLOAT32;
-    desc.element_size = 4;
-    resources.buffers.push_back({"buffer", desc});
-    return resources;
-  }
-};
-}  // namespace
-
 TEST(ArgumentsTest, TestSelectorResolve) {
-  TestDescriptor descriptor;
-  Arguments args;
-  args.AddObjectRef("object", AccessType::WRITE,
-                    absl::make_unique<TestDescriptor>(descriptor));
-  std::string sample_code = R"(
-  if (a < 3) {
-    value = args.object.Read(id);
-  }
-)";
-  const std::string expected_result = R"(
-  if (a < 3) {
-    value = object_buffer[id];
-  }
-)";
-  ASSERT_OK(args.TransformToCLCode({}, &sample_code));
-  EXPECT_EQ(sample_code, expected_result);
+  BufferDescriptor desc;
+  desc.element_type = DataType::FLOAT32;
+  desc.element_size = 4;
+  desc.memory_type = MemoryType::GLOBAL;
 
-  std::string cl_arguments = args.GetListOfArgs();
-  EXPECT_TRUE(cl_arguments.find("__global float4* object_buffer") !=
-              std::string::npos);
+  Arguments args;
+  args.AddObjectRef("weights", AccessType::READ,
+                    absl::make_unique<BufferDescriptor>(std::move(desc)));
+  std::string sample_code = R"(
+__kernel void main_function($0) {
+  if (a < 3) {
+    value = args.weights.Read(id);
+  }
+})";
+
+  DeviceInfo device_info;
+  ASSERT_OK(args.TransformToCLCode(device_info, {}, &sample_code));
+  EXPECT_TRUE(absl::StrContains(sample_code, "value = weights_buffer[id];"));
+  EXPECT_TRUE(
+      absl::StrContains(sample_code, "__global float4* weights_buffer"));
 }
 
 TEST(ArgumentsTest, TestNoSelector) {
-  TestDescriptor descriptor;
+  BufferDescriptor desc;
+  desc.element_type = DataType::FLOAT32;
+  desc.element_size = 4;
+  desc.memory_type = MemoryType::GLOBAL;
+
   Arguments args;
-  args.AddObjectRef("object", AccessType::WRITE,
-                    absl::make_unique<TestDescriptor>(descriptor));
+  args.AddObjectRef("weights", AccessType::READ,
+                    absl::make_unique<BufferDescriptor>(std::move(desc)));
   std::string sample_code = R"(
   if (a < 3) {
-    value = args.object.Write(id);
+    value = args.weights.UnknownSelector(id);
   }
 )";
-  EXPECT_FALSE(args.TransformToCLCode({}, &sample_code).ok());
+  DeviceInfo device_info;
+  EXPECT_FALSE(args.TransformToCLCode(device_info, {}, &sample_code).ok());
 }
 
 TEST(ArgumentsTest, TestRenameArgs) {
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index 2d40333..9cb8dde 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -160,7 +160,7 @@
   creation_context.queue = env->queue();
   creation_context.cache = env->program_cache();
 
-  ReserveGraphTensors(create_info, creation_context, graph);
+  ReserveGraphTensors(create_info, creation_context.GetDeviceInfo(), graph);
   precision_ = create_info.precision;
   storage_type_ = create_info.storage_type;
   if (env->device().IsMali()) {
@@ -174,10 +174,10 @@
     need_flush_ = true;
   }
   CopyInAndOutIds(graph);
-  RETURN_IF_ERROR(
-      ConvertOperations(creation_context, graph, create_info.hints));
+  RETURN_IF_ERROR(ConvertOperations(creation_context.GetDeviceInfo(), graph,
+                                    create_info.hints));
   RETURN_IF_ERROR(Merge());
-  RETURN_IF_ERROR(AllocateMemory(env->device(), creation_context.context));
+  RETURN_IF_ERROR(AllocateMemory(creation_context.context));
   BindMemoryToOperations();
   RETURN_IF_ERROR(Compile(creation_context));
   RETURN_IF_ERROR(UpdateParams());
@@ -213,8 +213,8 @@
 }
 
 void InferenceContext::ReserveGraphTensors(
-    const CreateInferenceInfo& create_info,
-    const CreationContext& creation_context, const GraphFloat32& graph) {
+    const CreateInferenceInfo& create_info, const DeviceInfo& device_info,
+    const GraphFloat32& graph) {
   ValueId max_id;
   auto tensors = graph.values();
   auto data_type = DeduceDataTypeFromPrecision(create_info.precision);
@@ -225,14 +225,14 @@
     if (graph.IsGraphInput(t->id) || graph.IsGraphOutput(t->id)) {
       if (shape.c < 4 &&
           CanCreateTensorWithShape(
-              creation_context.device->info_, shape,
+              device_info, shape,
               TensorDescriptor{data_type, TensorStorageType::SINGLE_TEXTURE_2D,
                                layout})) {
         storage_type = TensorStorageType::SINGLE_TEXTURE_2D;
       }
     }
-    storage_type = SelectBestStorageType(creation_context.device->info_, shape,
-                                         storage_type, data_type, layout);
+    storage_type = SelectBestStorageType(device_info, shape, storage_type,
+                                         data_type, layout);
     tensor_reserver_.Add(
         t->id, {shape, TensorDescriptor{data_type, storage_type, layout}});
     max_id = std::max(max_id, t->id);
@@ -240,9 +240,9 @@
   tensor_reserver_.SetNext(max_id + 1);
 }
 
-absl::Status InferenceContext::ConvertOperations(
-    const CreationContext& creation_context, const GraphFloat32& graph,
-    ModelHints hints) {
+absl::Status InferenceContext::ConvertOperations(const DeviceInfo& device_info,
+                                                 const GraphFloat32& graph,
+                                                 ModelHints hints) {
   std::map<ValueId, TensorDescriptor> tensor_descriptors;
   const auto values = graph.values();
   for (auto value : values) {
@@ -263,9 +263,8 @@
     }
     GPUOperationsSubgraph gpu_subgraph;
     if (hints.Check(ModelHints::kAllowSpecialKernels) &&
-        GPUSubgraphFromGraph(creation_context.device->info_, precision_, graph,
-                             node.id, tensor_descriptors, &consumed_nodes,
-                             &gpu_subgraph)
+        GPUSubgraphFromGraph(device_info, precision_, graph, node.id,
+                             tensor_descriptors, &consumed_nodes, &gpu_subgraph)
             .ok()) {
       // Mapping of subgraph (set of nodes) to GPU operations. Should happen
       // before straightforward mapping.
@@ -303,9 +302,8 @@
         op_def.dst_tensors.push_back(
             tensor_reserver_.Get(outputs[j]->id).descriptor);
       }
-      RETURN_IF_ERROR(GPUOperationFromNode(creation_context, op_def, hints,
-                                           inputs, outputs, node,
-                                           &gpu_subgraph));
+      RETURN_IF_ERROR(GPUOperationFromNode(device_info, op_def, hints, inputs,
+                                           outputs, node, &gpu_subgraph));
     }
     absl::flat_hash_map<int, ValueId> mapping_to_global_ids;
     for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
@@ -414,15 +412,13 @@
   }
 }
 
-absl::Status InferenceContext::AllocateMemory(const CLDevice& device,
-                                              CLContext* context) {
-  RETURN_IF_ERROR(AllocateMemoryForBuffers(device, context));
-  RETURN_IF_ERROR(AllocateMemoryForStrongShapes(device, context));
+absl::Status InferenceContext::AllocateMemory(CLContext* context) {
+  RETURN_IF_ERROR(AllocateMemoryForBuffers(context));
+  RETURN_IF_ERROR(AllocateMemoryForStrongShapes(context));
   return absl::OkStatus();
 }
 
-absl::Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device,
-                                                        CLContext* context) {
+absl::Status InferenceContext::AllocateMemoryForBuffers(CLContext* context) {
   std::map<ValueId, int2> buffer_usages;
   GetUsages(
       [](const TensorDescriptor& t) { return IsBufferBased(t.storage_type); },
@@ -474,7 +470,7 @@
 }
 
 absl::Status InferenceContext::AllocateMemoryForStrongShapes(
-    const CLDevice& device, CLContext* context) {
+    CLContext* context) {
   std::map<ValueId, int2> usages;
   GetUsages(
       [](const TensorDescriptor& t) { return !IsBufferBased(t.storage_type); },
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.h b/tensorflow/lite/delegates/gpu/cl/inference_context.h
index ab165f0..8486f2d 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.h
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.h
@@ -89,20 +89,18 @@
 
  private:
   void CopyInAndOutIds(const GraphFloat32& graph);
-  absl::Status ConvertOperations(const CreationContext& creation_context,
+  absl::Status ConvertOperations(const DeviceInfo& device_info,
                                  const GraphFloat32& graph, ModelHints hints);
   void CreateLinks();
   void ReserveGraphTensors(const CreateInferenceInfo& create_info,
-                           const CreationContext& creation_context,
+                           const DeviceInfo& device_info,
                            const GraphFloat32& graph);
   absl::Status Merge();
-  absl::Status AllocateMemory(const CLDevice& device, CLContext* context);
+  absl::Status AllocateMemory(CLContext* context);
 
-  absl::Status AllocateMemoryForBuffers(const CLDevice& device,
-                                        CLContext* context);
+  absl::Status AllocateMemoryForBuffers(CLContext* context);
 
-  absl::Status AllocateMemoryForStrongShapes(const CLDevice& device,
-                                             CLContext* context);
+  absl::Status AllocateMemoryForStrongShapes(CLContext* context);
 
   // utility function
   void GetUsages(const std::function<bool(const TensorDescriptor&)>& functor,
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
index 0843fe5..15681be 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
@@ -941,6 +941,7 @@
         "//tensorflow/lite/delegates/gpu/cl:cl_context",
         "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
         "//tensorflow/lite/delegates/gpu/cl:linear_storage",
+        "//tensorflow/lite/delegates/gpu/cl:storage_type_util",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
@@ -1005,6 +1006,37 @@
 )
 
 cc_library(
+    name = "reduce",
+    srcs = ["reduce.cc"],
+    hdrs = ["reduce.h"],
+    deps = [
+        ":gpu_operation",
+        ":util",
+        "//tensorflow/lite/delegates/gpu/cl:precision",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:types",
+    ],
+)
+
+cc_test(
+    name = "reduce_test",
+    srcs = ["reduce_test.cc"],
+    linkstatic = True,
+    tags = tf_gpu_tests_tags() + [
+        "linux",
+        "local",
+    ],
+    deps = [
+        ":cl_test",
+        ":reduce",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_library(
     name = "relu",
     srcs = ["relu.cc"],
     hdrs = ["relu.h"],
@@ -1396,6 +1428,7 @@
         "padding_test",
         "pooling_test",
         "prelu_test",
+        "reduce_test",
         "relu_test",
         "reshape_test",
         "reshapex4_test",
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc
index dc54286..c366363 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.cc
@@ -45,84 +45,29 @@
     return GetAdrenoOptimalMaxConstantSize(info.adreno_info.gpu_version);
   }
 }
-}  // namespace
 
-ConvConstants::ConvConstants(const OperationDef& definition,
-                             const Convolution2DAttributes& attr,
-                             const DeviceInfo& device_info)
-    : GPUOperation(definition),
-      kernel_size_(attr.weights.shape.w, attr.weights.shape.h),
-      stride_(attr.strides.w, attr.strides.h),
-      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h),
-      dilation_(attr.dilations.w, attr.dilations.h),
-      src_channels_(attr.weights.shape.i),
-      dst_channels_(attr.weights.shape.o) {
-  const bool stride_correction =
-      definition_.IsBatchSupported() && stride_.x != 1;
-  code_ =
-      GenerateConvolutionConstantCode(definition_, kernel_size_, src_channels_,
-                                      dst_channels_, stride_correction);
-  if (definition_.precision == CalculationsPrecision::F16 &&
-      device_info.IsAdreno3xx()) {
-    compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE);
-  }
-  if (definition_.precision != CalculationsPrecision::F32 &&
-      device_info.IsPowerVR()) {
-    // BUG, some PowerVRs (GE8320) produce incorrect result without it
-    compiler_options_.push_back(CompilerOptions::CL_OPT_DISABLE);
-  }
-}
-
-ConvConstants::ConvConstants(ConvConstants&& kernel)
-    : GPUOperation(std::move(kernel)),
-      kernel_size_(kernel.kernel_size_),
-      stride_(kernel.stride_),
-      padding_(kernel.padding_),
-      dilation_(kernel.dilation_),
-      src_channels_(kernel.src_channels_),
-      dst_channels_(kernel.dst_channels_) {}
-
-ConvConstants& ConvConstants::operator=(ConvConstants&& kernel) {
-  if (this != &kernel) {
-    std::swap(kernel_size_, kernel.kernel_size_);
-    std::swap(stride_, kernel.stride_);
-    std::swap(padding_, kernel.padding_);
-    std::swap(dilation_, kernel.dilation_);
-    std::swap(src_channels_, kernel.src_channels_);
-    std::swap(dst_channels_, kernel.dst_channels_);
-    GPUOperation::operator=(std::move(kernel));
-  }
-  return *this;
-}
-
-std::string ConvConstants::GenerateConvolutionConstantCode(
-    const OperationDef& op_def, const int2& kernel_size, int src_channels,
-    int dst_channels, bool stride_correction) {
+std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
+                                            const OHWI& weights_shape,
+                                            bool stride_correction,
+                                            GPUOperation* op) {
   auto src_desc = op_def.src_tensors[0];
   src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
   if (op_def.IsBatchSupported()) {
     src_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddSrcTensor("src_tensor", src_desc);
+  op->AddSrcTensor("src_tensor", src_desc);
 
   auto dst_desc = op_def.dst_tensors[0];
   if (op_def.IsBatchSupported()) {
     dst_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddDstTensor("dst_tensor", dst_desc);
-
-  args_.AddInt("stride_x");
-  args_.AddInt("stride_y");
-  args_.AddInt("padding_x");
-  args_.AddInt("padding_y");
-  args_.AddInt("dilation_x");
-  args_.AddInt("dilation_y");
+  op->AddDstTensor("dst_tensor", dst_desc);
 
   std::string c = GetCommonDefines(op_def.precision);
 
-  const int out_z = DivideRoundUp(dst_channels, 4);
+  const int out_z = DivideRoundUp(weights_shape.o, 4);
   const std::string kOutZ = std::to_string(out_z);
-  const int src_depth = DivideRoundUp(src_channels, 4);
+  const int src_depth = DivideRoundUp(weights_shape.i, 4);
 
   const auto src_tensor_type = op_def.src_tensors[0].storage_type;
   const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
@@ -176,11 +121,16 @@
        "return;\n";
   if (stride_correction) {
     c += "  int start_x = " +
-         GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x",
-                             "args.padding_x") +
+         GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x",
+                               "args.padding_x") +
          ";\n";
   } else {
-    c += "  int start_x = X * args.stride_x + args.padding_x;\n";
+    if (op_def.IsBatchSupported()) {
+      c += "  int start_x = X * args.stride_x + args.padding_x * "
+           "args.src_tensor.Batch();\n";
+    } else {
+      c += "  int start_x = X * args.stride_x + args.padding_x;\n";
+    }
   }
   c += "  int start_y = Y * args.stride_y + args.padding_y;\n";
   c += "  ACCUM_FLT4 r[" + kOutZ + "];\n";
@@ -189,22 +139,25 @@
   c += "  }\n";
   int filters_counter = 0;
   for (int s = 0; s < src_depth; ++s) {
-    const int ch_count = std::min(4, src_channels - s * 4);
+    const int ch_count = std::min(4, weights_shape.i - s * 4);
     const std::string s_conv = "CONV" + std::to_string(ch_count);
     const std::string s_count = ch_count == 1 ? "" : std::to_string(ch_count);
     const std::string s_type = absl::StrCat("FLT", s_count);
     const std::string s_postfix = postfixes[ch_count - 1];
-    for (int ky = 0; ky < kernel_size.y; ++ky) {
+    const std::string dilation_x =
+        op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
+                                  : "args.dilation_x";
+    for (int ky = 0; ky < weights_shape.h; ++ky) {
       std::string s_y = absl::StrCat("(start_y + ", ky, " * args.dilation_y)");
       if (manual_clamp) {
         c += "  {\n";
         c += "  bool y_out = " + s_y + " < 0 || " + s_y +
              " >= args.src_tensor.Height();\n";
       }
-      for (int kx = 0; kx < kernel_size.x; ++kx) {
+      for (int kx = 0; kx < weights_shape.w; ++kx) {
         c += "  {\n";
         std::string s_x =
-            absl::StrCat("(start_x + ", kx, " * args.dilation_x)");
+            absl::StrCat("(start_x + ", kx, " * " + dilation_x + ")");
         if (manual_clamp) {
           c += "    bool x_out = " + s_x + "< 0 || " + s_x +
                ">= args.src_tensor.Width();\n";
@@ -240,20 +193,7 @@
   return c;
 }
 
-absl::Status ConvConstants::BindArguments() {
-  RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
-  RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
-  RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch()));
-  RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
-  RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch()));
-  return args_.SetInt("dilation_y", dilation_.y);
-}
-
-int3 ConvConstants::GetGridSize() const {
-  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
-  const int grid_y = dst_[0]->Height();
-  return int3(grid_x, grid_y, 1);
-}
+}  // namespace
 
 bool IsConvConstantsSupported(const DeviceInfo& device_info,
                               const OperationDef& definition,
@@ -277,20 +217,41 @@
   return filters_buffer_size <= kConstantMaxSize && flt4_registers <= 8;
 }
 
-ConvConstants CreateConvConstants(const DeviceInfo& device_info,
-                                  const OperationDef& definition,
-                                  const Convolution2DAttributes& attr) {
-  ConvConstants result(definition, attr, device_info);
-  result.UploadWeights(attr.weights);
+GPUOperation CreateConvConstants(const DeviceInfo& device_info,
+                                 const OperationDef& definition,
+                                 const Convolution2DAttributes& attr) {
+  GPUOperation op(definition);
+  UploadWeightsForConvConstants(attr.weights, definition.precision, &op);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("stride_y", attr.strides.h);
+  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+  op.args_.AddInt("dilation_x", attr.dilations.w);
+  op.args_.AddInt("dilation_y", attr.dilations.h);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_ZIs1;
+
+  const bool stride_correction =
+      definition.IsBatchSupported() && attr.strides.w != 1;
+  op.code_ = GenerateConvolutionConstantCode(definition, attr.weights.shape,
+                                             stride_correction, &op);
+  if (definition.precision == CalculationsPrecision::F16 &&
+      device_info.IsAdreno3xx()) {
+    op.compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE);
+  }
+  if (definition.precision != CalculationsPrecision::F32 &&
+      device_info.IsPowerVR()) {
+      // BUG: some PowerVRs (GE8320) produce an incorrect result without it.
+    op.compiler_options_.push_back(CompilerOptions::CL_OPT_DISABLE);
+  }
 
   TensorLinearDescriptor desc;
   desc.storage_type = LinearStorageType::BUFFER;
   desc.element_type = definition.GetDataType();
   desc.memory_type = MemoryType::CONSTANT;
   desc.UploadLinearData(attr.bias);
-  result.args_.AddObject(
+  op.args_.AddObject(
       "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
-  return result;
+  return op;
 }
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h
index 5be4335..c341ecb 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h
@@ -32,78 +32,8 @@
 namespace gpu {
 namespace cl {
 
-class ConvConstants : public GPUOperation {
- public:
-  ConvConstants() = default;
-  absl::Status BindArguments() override;
-  int3 GetGridSize() const override;
-
-  // Move only
-  ConvConstants(ConvConstants&& kernel);
-  ConvConstants& operator=(ConvConstants&& kernel);
-  ConvConstants(const ConvConstants&) = delete;
-  ConvConstants& operator=(const ConvConstants&) = delete;
-
- private:
-  friend ConvConstants CreateConvConstants(const DeviceInfo& device_info,
-                                           const OperationDef& definition,
-                                           const Convolution2DAttributes& attr);
-  ConvConstants(const OperationDef& definition,
-                const Convolution2DAttributes& attr,
-                const DeviceInfo& device_info);
-
-  template <DataType T>
-  void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights);
-
-  template <DataType S, typename T>
-  void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
-                            absl::Span<T> dst);
-
-  std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
-                                              const int2& kernel_size,
-                                              int src_channels,
-                                              int dst_channels,
-                                              bool stride_correction);
-
-  int2 kernel_size_;
-  int2 stride_;
-  int2 padding_;
-  int2 dilation_;
-  int src_channels_;
-  int dst_channels_;
-};
-
-template <DataType T>
-void ConvConstants::UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights) {
-  const int dst_depth = DivideRoundUp(weights.shape.o, 4);
-  const int kernel_x = weights.shape.w;
-  const int kernel_y = weights.shape.h;
-
-  const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
-  const int float_size = f32_weights ? 4 : 2;
-  const int float_count = src_channels_ * dst_depth * 4 * kernel_x * kernel_y;
-
-  BufferDescriptor desc;
-  desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
-  desc.element_size = 4;
-  desc.memory_type = MemoryType::CONSTANT;
-  desc.size = float_size * float_count;
-  desc.data.resize(desc.size);
-
-  if (f32_weights) {
-    float4* ptr = reinterpret_cast<float4*>(desc.data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, float_count / 4));
-  } else {
-    half4* ptr = reinterpret_cast<half4*>(desc.data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, float_count / 4));
-  }
-
-  args_.AddObject("weigths",
-                  absl::make_unique<BufferDescriptor>(std::move(desc)));
-}
-
 template <DataType S, typename T>
-void ConvConstants::RearrangeWeightsData(
+void RearrangeWeightsForConvConstants(
     const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
   const int dst_depth = DivideRoundUp(weights.shape.o, 4);
   const int src_depth = DivideRoundUp(weights.shape.i, 4);
@@ -115,7 +45,7 @@
     for (int y = 0; y < kernel_y; ++y) {
       for (int x = 0; x < kernel_x; ++x) {
         for (int d = 0; d < dst_depth; ++d) {
-          const int channels_count = std::min(4, src_channels_ - s * 4);
+          const int channels_count = std::min(4, weights.shape.i - s * 4);
           T filters[4];
           for (int i = 0; i < 4; ++i) {
             for (int j = 0; j < channels_count; ++j) {
@@ -145,13 +75,46 @@
   }
 }
 
+template <DataType T>
+void UploadWeightsForConvConstants(const tflite::gpu::Tensor<OHWI, T>& weights,
+                                   CalculationsPrecision precision,
+                                   GPUOperation* op) {
+  const int dst_depth = DivideRoundUp(weights.shape.o, 4);
+  const int kernel_x = weights.shape.w;
+  const int kernel_y = weights.shape.h;
+
+  const bool f32_weights = precision == CalculationsPrecision::F32;
+  const int float_size = f32_weights ? 4 : 2;
+  const int float_count = weights.shape.i * dst_depth * 4 * kernel_x * kernel_y;
+
+  BufferDescriptor desc;
+  desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+  desc.element_size = 4;
+  desc.memory_type = MemoryType::CONSTANT;
+  desc.size = float_size * float_count;
+  desc.data.resize(desc.size);
+
+  if (f32_weights) {
+    float4* ptr = reinterpret_cast<float4*>(desc.data.data());
+    RearrangeWeightsForConvConstants(weights,
+                                     absl::MakeSpan(ptr, float_count / 4));
+  } else {
+    half4* ptr = reinterpret_cast<half4*>(desc.data.data());
+    RearrangeWeightsForConvConstants(weights,
+                                     absl::MakeSpan(ptr, float_count / 4));
+  }
+
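+  // NOTE: "weigths" (sic) is kept as-is to match the argument name the
+  // generated kernel code reads.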
+  op->args_.AddObject("weigths",
+                      absl::make_unique<BufferDescriptor>(std::move(desc)));
+}
+
 bool IsConvConstantsSupported(const DeviceInfo& device_info,
                               const OperationDef& definition,
                               const Convolution2DAttributes& attr);
 
-ConvConstants CreateConvConstants(const DeviceInfo& device_info,
-                                  const OperationDef& definition,
-                                  const Convolution2DAttributes& attr);
+GPUOperation CreateConvConstants(const DeviceInfo& device_info,
+                                 const OperationDef& definition,
+                                 const Convolution2DAttributes& attr);
 
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants_test.cc
index 4aa60b8..17821e1 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_constants_test.cc
@@ -55,7 +55,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      ConvConstants operation =
+      GPUOperation operation =
           CreateConvConstants(creation_context_.GetDeviceInfo(), op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 2, 2, 1), &dst_tensor));
@@ -90,7 +90,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      ConvConstants operation =
+      GPUOperation operation =
           CreateConvConstants(creation_context_.GetDeviceInfo(), op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 2, 2, 2), &dst_tensor));
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
index 5dbb191..1852223 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
@@ -33,11 +33,9 @@
     const OperationDef& definition, const ConvolutionTransposedAttributes& attr,
     const DeviceInfo& device_info)
     : GPUOperation(definition),
-      weights_are_buffer_(device_info.IsMali()),
-      kernel_size_(attr.weights.shape.w, attr.weights.shape.h),
       stride_(attr.stride.w, attr.stride.h),
-      padding_(attr.padding.prepended.w, attr.padding.prepended.h),
       block_size_(2, 2, 2) {
+  const bool weights_are_buffer = device_info.IsMali();
   const bool is_f16 = definition.precision == CalculationsPrecision::F16;
   if (device_info.IsMali()) {
     if (device_info.mali_info.IsMidgard()) {
@@ -54,25 +52,26 @@
     block_size_.z = 1;
   }
 
+  args_.AddInt("stride_x", stride_.x);
+  args_.AddInt("stride_y", stride_.y);
+  args_.AddInt("padding_x", attr.padding.prepended.w);
+  args_.AddInt("padding_y", attr.padding.prepended.h);
+  args_.AddInt("kernel_size_x", attr.weights.shape.w);
+  args_.AddInt("kernel_size_y", attr.weights.shape.h);
   code_ = GenerateConvolutionTransposedCode(definition_, device_info,
-                                            weights_are_buffer_, block_size_);
+                                            weights_are_buffer, block_size_);
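+  // Uploading here (rather than in CreateConvolutionTransposed) lets the
+  // weights_are_buffer flag stay a local instead of a member.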
+  UploadWeights(attr.weights, weights_are_buffer);
 }
 
 ConvolutionTransposed::ConvolutionTransposed(ConvolutionTransposed&& operation)
     : GPUOperation(std::move(operation)),
-      weights_are_buffer_(operation.weights_are_buffer_),
-      kernel_size_(operation.kernel_size_),
       stride_(operation.stride_),
-      padding_(operation.padding_),
       block_size_(operation.block_size_) {}
 
 ConvolutionTransposed& ConvolutionTransposed::operator=(
     ConvolutionTransposed&& operation) {
   if (this != &operation) {
-    std::swap(weights_are_buffer_, operation.weights_are_buffer_);
-    std::swap(kernel_size_, operation.kernel_size_);
     std::swap(stride_, operation.stride_);
-    std::swap(padding_, operation.padding_);
     std::swap(block_size_, operation.block_size_);
     GPUOperation::operator=(std::move(operation));
   }
@@ -88,13 +87,6 @@
 
   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
 
-  args_.AddInt("stride_x");
-  args_.AddInt("stride_y");
-  args_.AddInt("padding_x");
-  args_.AddInt("padding_y");
-  args_.AddInt("kernel_size_x");
-  args_.AddInt("kernel_size_y");
-
   const auto src_tensor_type = op_def.src_tensors[0].storage_type;
   bool image_buffer = src_tensor_type == TensorStorageType::IMAGE_BUFFER;
   bool manual_clamp =
@@ -333,15 +325,6 @@
   return c;
 }
 
-absl::Status ConvolutionTransposed::BindArguments() {
-  RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
-  RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
-  RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x));
-  RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
-  RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
-  return args_.SetInt("kernel_size_y", kernel_size_.y);
-}
-
 int3 ConvolutionTransposed::GetGridSize() const {
   const int aligned_w = AlignByN(dst_[0]->Width(), stride_.x * block_size_.x);
   const int aligned_h = AlignByN(dst_[0]->Height(), stride_.y * block_size_.y);
@@ -362,7 +345,6 @@
     const DeviceInfo& device_info, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr) {
   ConvolutionTransposed result(definition, attr, device_info);
-  result.UploadWeights(attr.weights);
 
   TensorLinearDescriptor desc;
   desc.storage_type =
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
index d8ecaca..7939236 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
@@ -42,7 +42,6 @@
       TuningType tuning_type, const DeviceInfo& device_info,
       const KernelInfo& kernel_info,
       std::vector<int3>* work_groups) const override;
-  absl::Status BindArguments() override;
   int3 GetGridSize() const override;
 
   // Move only
@@ -59,34 +58,29 @@
                                  const ConvolutionTransposedAttributes& attr,
                                  const DeviceInfo& device_info);
   template <DataType T>
-  void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights);
+  void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
+                     bool weights_are_buffer);
 
   template <DataType S, typename T>
   void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
-                            absl::Span<T> dst);
+                            absl::Span<T> dst, bool weights_are_buffer);
 
   std::string GenerateConvolutionTransposedCode(const OperationDef& op_def,
                                                 const DeviceInfo& device_info,
                                                 bool weights_are_buffer,
                                                 const int3& block_size);
-
-  bool weights_are_buffer_;
-
-  int2 kernel_size_;
   int2 stride_;
-  int2 padding_;
-
   int3 block_size_ = int3(1, 1, 1);
 };
 
 template <DataType T>
 void ConvolutionTransposed::UploadWeights(
-    const tflite::gpu::Tensor<OHWI, T>& weights) {
+    const tflite::gpu::Tensor<OHWI, T>& weights, bool weights_are_buffer) {
   const int dst_depth =
       AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.z);
   const int src_depth = DivideRoundUp(weights.shape.i, 4);
-  const int kernel_x = kernel_size_.x;
-  const int kernel_y = kernel_size_.y;
+  const int kernel_x = weights.shape.w;
+  const int kernel_y = weights.shape.h;
 
   const int elements_count = kernel_x * kernel_y * src_depth * dst_depth * 4;
   const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
@@ -96,13 +90,15 @@
 
   if (f32_weights) {
     float4* ptr = reinterpret_cast<float4*>(data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count));
+    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count),
+                         weights_are_buffer);
   } else {
     half4* ptr = reinterpret_cast<half4*>(data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count));
+    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count),
+                         weights_are_buffer);
   }
 
-  if (weights_are_buffer_) {
+  if (weights_are_buffer) {
     BufferDescriptor desc;
     desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
     desc.element_size = 16;
@@ -148,12 +144,13 @@
 
 template <DataType S, typename T>
 void ConvolutionTransposed::RearrangeWeightsData(
-    const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
+    const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst,
+    bool weights_are_buffer) {
   const int dst_depth =
       AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.z);
   const int src_depth = DivideRoundUp(weights.shape.i, 4);
-  const int kernel_x = kernel_size_.x;
-  const int kernel_y = kernel_size_.y;
+  const int kernel_x = weights.shape.w;
+  const int kernel_y = weights.shape.h;
   int texture_width = dst_depth;
   int texture_height = src_depth * kernel_x * kernel_y;
 
@@ -177,7 +174,7 @@
                 }
               }
             }
-            if (weights_are_buffer_) {
+            if (weights_are_buffer) {
               dst[counter++] = filters[0];
               dst[counter++] = filters[1];
               dst[counter++] = filters[2];
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc
index 443a621..b2a85a8 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.cc
@@ -33,15 +33,22 @@
     const ConvolutionTransposed3DAttributes& attr,
     const DeviceInfo& device_info)
     : GPUOperation(definition),
-      weights_are_buffer_(device_info.IsMali()),
-      kernel_size_(attr.weights.shape.w, attr.weights.shape.h,
-                   attr.weights.shape.d),
       stride_(attr.stride.w, attr.stride.h, attr.stride.d),
-      padding_(attr.padding.prepended.w, attr.padding.prepended.h,
-               attr.padding.prepended.d),
       block_size_(2, 2, 1, 2) {
-  code_ = GenerateConvolutionTransposed3DCode(definition_, weights_are_buffer_,
+  const bool weights_are_buffer = device_info.IsMali();
+  args_.AddInt("stride_x", stride_.x);
+  args_.AddInt("stride_y", stride_.y);
+  args_.AddInt("stride_z", stride_.z);
+  args_.AddInt("padding_x", attr.padding.prepended.w);
+  args_.AddInt("padding_y", attr.padding.prepended.h);
+  args_.AddInt("padding_z", attr.padding.prepended.d);
+  args_.AddInt("kernel_size_x", attr.weights.shape.w);
+  args_.AddInt("kernel_size_y", attr.weights.shape.h);
+  args_.AddInt("kernel_size_z", attr.weights.shape.d);
+  args_.AddInt("grid_size_s");
+  code_ = GenerateConvolutionTransposed3DCode(definition_, weights_are_buffer,
                                               block_size_);
+  UploadWeights(attr.weights, weights_are_buffer);
   if (device_info.IsPowerVR() && block_size_.y != 1) {
     bool is_texture3d = definition_.src_tensors[0].storage_type ==
                         TensorStorageType::TEXTURE_3D;
@@ -56,19 +63,13 @@
 ConvolutionTransposed3D::ConvolutionTransposed3D(
     ConvolutionTransposed3D&& operation)
     : GPUOperation(std::move(operation)),
-      weights_are_buffer_(operation.weights_are_buffer_),
-      kernel_size_(operation.kernel_size_),
       stride_(operation.stride_),
-      padding_(operation.padding_),
       block_size_(operation.block_size_) {}
 
 ConvolutionTransposed3D& ConvolutionTransposed3D::operator=(
     ConvolutionTransposed3D&& operation) {
   if (this != &operation) {
-    std::swap(weights_are_buffer_, operation.weights_are_buffer_);
-    std::swap(kernel_size_, operation.kernel_size_);
     std::swap(stride_, operation.stride_);
-    std::swap(padding_, operation.padding_);
     std::swap(block_size_, operation.block_size_);
     GPUOperation::operator=(std::move(operation));
   }
@@ -84,17 +85,6 @@
 
   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
 
-  args_.AddInt("stride_x");
-  args_.AddInt("stride_y");
-  args_.AddInt("stride_z");
-  args_.AddInt("padding_x");
-  args_.AddInt("padding_y");
-  args_.AddInt("padding_z");
-  args_.AddInt("kernel_size_x");
-  args_.AddInt("kernel_size_y");
-  args_.AddInt("kernel_size_z");
-  args_.AddInt("grid_size_s");
-
   const auto src_tensor_type = op_def.src_tensors[0].storage_type;
   bool image_buffer = src_tensor_type == TensorStorageType::IMAGE_BUFFER;
   bool manual_clamp =
@@ -370,15 +360,6 @@
 }
 
 absl::Status ConvolutionTransposed3D::BindArguments() {
-  RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
-  RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
-  RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z));
-  RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x));
-  RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
-  RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z));
-  RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
-  RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y));
-  RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z));
   return args_.SetInt("grid_size_s",
                       DivideRoundUp(dst_[0]->Slices(), block_size_.w));
 }
@@ -405,7 +386,6 @@
     const DeviceInfo& device_info, const OperationDef& definition,
     const ConvolutionTransposed3DAttributes& attr) {
   ConvolutionTransposed3D result(definition, attr, device_info);
-  result.UploadWeights(attr.weights);
 
   TensorLinearDescriptor desc;
   desc.storage_type =
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h
index 3285dfc..ebd674d 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3d.h
@@ -59,34 +59,30 @@
                           const ConvolutionTransposed3DAttributes& attr,
                           const DeviceInfo& device_info);
   template <DataType T>
-  void UploadWeights(const tflite::gpu::Tensor<OHWDI, T>& weights);
+  void UploadWeights(const tflite::gpu::Tensor<OHWDI, T>& weights,
+                     bool weights_are_buffer);
 
   template <DataType S, typename T>
   void RearrangeWeightsData(const tflite::gpu::Tensor<OHWDI, S>& weights,
-                            absl::Span<T> dst);
+                            absl::Span<T> dst, bool weights_are_buffer);
 
   std::string GenerateConvolutionTransposed3DCode(const OperationDef& op_def,
                                                   bool weights_are_buffer,
                                                   const int4& block_size);
 
-  bool weights_are_buffer_;
-
-  int3 kernel_size_;
   int3 stride_;
-  int3 padding_;
-
   int4 block_size_ = int4(1, 1, 1, 1);  // WHDS
 };
 
 template <DataType T>
 void ConvolutionTransposed3D::UploadWeights(
-    const tflite::gpu::Tensor<OHWDI, T>& weights) {
+    const tflite::gpu::Tensor<OHWDI, T>& weights, bool weights_are_buffer) {
   const int dst_depth =
       AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.z);
   const int src_depth = DivideRoundUp(weights.shape.i, 4);
-  const int kernel_x = kernel_size_.x;
-  const int kernel_y = kernel_size_.y;
-  const int kernel_z = kernel_size_.z;
+  const int kernel_x = weights.shape.w;
+  const int kernel_y = weights.shape.h;
+  const int kernel_z = weights.shape.d;
   int texture_width = dst_depth;
   int texture_height = src_depth * kernel_x * kernel_y * kernel_z;
 
@@ -99,13 +95,15 @@
 
   if (f32_weights) {
     float4* ptr = reinterpret_cast<float4*>(data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count));
+    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count),
+                         weights_are_buffer);
   } else {
     half4* ptr = reinterpret_cast<half4*>(data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count));
+    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count),
+                         weights_are_buffer);
   }
 
-  if (weights_are_buffer_) {
+  if (weights_are_buffer) {
     BufferDescriptor desc;
     desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
     desc.element_size = 16;
@@ -151,13 +149,14 @@
 
 template <DataType S, typename T>
 void ConvolutionTransposed3D::RearrangeWeightsData(
-    const tflite::gpu::Tensor<OHWDI, S>& weights, absl::Span<T> dst) {
+    const tflite::gpu::Tensor<OHWDI, S>& weights, absl::Span<T> dst,
+    bool weights_are_buffer) {
   const int dst_depth =
       AlignByN(DivideRoundUp(weights.shape.o, 4), block_size_.w);
   const int src_depth = DivideRoundUp(weights.shape.i, 4);
-  const int kernel_x = kernel_size_.x;
-  const int kernel_y = kernel_size_.y;
-  const int kernel_z = kernel_size_.z;
+  const int kernel_x = weights.shape.w;
+  const int kernel_y = weights.shape.h;
+  const int kernel_z = weights.shape.d;
   int texture_width = dst_depth;
   int texture_height = src_depth * kernel_x * kernel_y * kernel_z;
 
@@ -182,7 +181,7 @@
                   }
                 }
               }
-              if (weights_are_buffer_) {
+              if (weights_are_buffer) {
                 dst[counter++] = filters[0];
                 dst[counter++] = filters[1];
                 dst[counter++] = filters[2];
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc
index 2f6010b..d606a82 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc
@@ -28,20 +28,23 @@
 namespace gpu {
 namespace cl {
 ConvolutionTransposed4x4::ConvolutionTransposed4x4(
-    const OperationDef& definition, const DeviceInfo& device_info)
+    const OperationDef& definition, const DeviceInfo& device_info,
+    const ConvolutionTransposedAttributes& attr)
     : GPUOperation(definition) {
   work_group_size_ = int3(8, 4, 1);
+  WeightsUploadType weights_upload_type = WeightsUploadType::GLOBAL_MEM;
   if (device_info.IsPowerVR()) {
-    weights_upload_type_ = WeightsUploadType::LOCAL_MEM_ASYNC;
+    weights_upload_type = WeightsUploadType::LOCAL_MEM_ASYNC;
   } else if (device_info.IsNvidia() || device_info.IsIntel()) {
-    weights_upload_type_ = WeightsUploadType::LOCAL_MEM_BY_THREADS;
+    weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
   } else if (device_info.IsAMD()) {
-    weights_upload_type_ = WeightsUploadType::CONSTANT_MEM;
+    weights_upload_type = WeightsUploadType::CONSTANT_MEM;
   } else {
-    weights_upload_type_ = WeightsUploadType::GLOBAL_MEM;
+    weights_upload_type = WeightsUploadType::GLOBAL_MEM;
   }
 
-  code_ = GenerateConvolutionTransposedCode(definition_, weights_upload_type_);
+  code_ = GenerateConvolutionTransposedCode(definition_, weights_upload_type);
+  UploadWeights(attr.weights, weights_upload_type);
   if (definition_.precision == CalculationsPrecision::F16 &&
       device_info.IsPowerVR()) {
     compiler_options_.push_back(CompilerOptions::POWERVR_FP16);
@@ -50,13 +53,11 @@
 
 ConvolutionTransposed4x4::ConvolutionTransposed4x4(
     ConvolutionTransposed4x4&& operation)
-    : GPUOperation(std::move(operation)),
-      weights_upload_type_(operation.weights_upload_type_) {}
+    : GPUOperation(std::move(operation)) {}
 
 ConvolutionTransposed4x4& ConvolutionTransposed4x4::operator=(
     ConvolutionTransposed4x4&& operation) {
   if (this != &operation) {
-    std::swap(weights_upload_type_, operation.weights_upload_type_);
     GPUOperation::operator=(std::move(operation));
   }
   return *this;
@@ -317,8 +318,7 @@
 ConvolutionTransposed4x4 CreateConvolutionTransposed4x4(
     const DeviceInfo& device_info, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr) {
-  ConvolutionTransposed4x4 result(definition, device_info);
-  result.UploadWeights(attr.weights);
+  ConvolutionTransposed4x4 result(definition, device_info, attr);
 
   TensorLinearDescriptor desc;
   desc.storage_type = LinearStorageType::TEXTURE_2D;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h
index dd1084b..2577eb4 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h
@@ -61,12 +61,14 @@
 
  private:
   ConvolutionTransposed4x4(const OperationDef& definition,
-                           const DeviceInfo& device_info);
+                           const DeviceInfo& device_info,
+                           const ConvolutionTransposedAttributes& attr);
   friend ConvolutionTransposed4x4 CreateConvolutionTransposed4x4(
       const DeviceInfo& device_info, const OperationDef& definition,
       const ConvolutionTransposedAttributes& attr);
   template <DataType T>
-  void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights);
+  void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
+                     WeightsUploadType weights_upload_type);
 
   template <DataType S, typename T>
   void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
@@ -74,13 +76,12 @@
 
   std::string GenerateConvolutionTransposedCode(
       const OperationDef& op_def, WeightsUploadType weights_upload_type);
-
-  WeightsUploadType weights_upload_type_;
 };
 
 template <DataType T>
 void ConvolutionTransposed4x4::UploadWeights(
-    const tflite::gpu::Tensor<OHWI, T>& weights) {
+    const tflite::gpu::Tensor<OHWI, T>& weights,
+    WeightsUploadType weights_upload_type) {
   const int src_depth = DivideRoundUp(weights.shape.i, 4);
   const int dst_depth = DivideRoundUp(weights.shape.o, 4);
   const int kernel_x = 4;  // This operation supports only a 4x4 kernel
@@ -94,7 +95,7 @@
   desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
   desc.element_size = 4;
   desc.memory_type =
-      weights_upload_type_ ==
+      weights_upload_type ==
               ConvolutionTransposed4x4::WeightsUploadType::CONSTANT_MEM
           ? MemoryType::CONSTANT
           : MemoryType::GLOBAL;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc
index 91e26b2..f42ad82 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc
@@ -66,100 +66,24 @@
 
   return c;
 }
-}  // namespace
 
-DepthwiseConvolution::DepthwiseConvolution(
-    const OperationDef& definition,
-    const DepthwiseConvolution2DAttributes& attr, bool weights_are_buffer)
-    : GPUOperation(definition),
-      weights_are_buffer_(weights_are_buffer),
-      kernel_size_(attr.weights.shape.w, attr.weights.shape.h, 0, 0),
-      stride_(attr.strides.w, attr.strides.h, 0, 0),
-      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
-      dilation_(attr.dilations.w, attr.dilations.h, 0, 0),
-      channel_multiplier_(attr.weights.shape.o) {
-  work_group_size_ = int3(8, 8, 1);
-  const bool stride_correction =
-      definition_.IsBatchSupported() && stride_.x != 1;
-  code_ = GenerateDepthwiseConvolutionCode(
-      definition_, stride_correction, channel_multiplier_, weights_are_buffer_);
-}
-
-DepthwiseConvolution::DepthwiseConvolution(
-    const OperationDef& definition,
-    const DepthwiseConvolution3DAttributes& attr, bool weights_are_buffer)
-    : GPUOperation(definition),
-      weights_are_buffer_(weights_are_buffer),
-      kernel_size_(attr.weights.shape.w, attr.weights.shape.h,
-                   attr.weights.shape.d, 0),
-      stride_(attr.strides.w, attr.strides.h, attr.strides.d, 0),
-      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h,
-               -attr.padding.prepended.d, 0),
-      dilation_(attr.dilations.w, attr.dilations.h, attr.dilations.d, 0),
-      channel_multiplier_(attr.weights.shape.o) {
-  work_group_size_ = int3(8, 8, 1);
-  const bool stride_correction =
-      definition_.IsBatchSupported() && stride_.x != 1;
-  code_ = GenerateDepthwiseConvolutionCode(
-      definition_, stride_correction, channel_multiplier_, weights_are_buffer_);
-}
-
-DepthwiseConvolution::DepthwiseConvolution(DepthwiseConvolution&& operation)
-    : GPUOperation(std::move(operation)),
-      weights_are_buffer_(operation.weights_are_buffer_),
-      kernel_size_(operation.kernel_size_),
-      stride_(operation.stride_),
-      padding_(operation.padding_),
-      dilation_(operation.dilation_),
-      channel_multiplier_(operation.channel_multiplier_) {}
-
-DepthwiseConvolution& DepthwiseConvolution::operator=(
-    DepthwiseConvolution&& operation) {
-  if (this != &operation) {
-    std::swap(weights_are_buffer_, operation.weights_are_buffer_);
-    std::swap(kernel_size_, operation.kernel_size_);
-    std::swap(stride_, operation.stride_);
-    std::swap(padding_, operation.padding_);
-    std::swap(dilation_, operation.dilation_);
-    std::swap(channel_multiplier_, operation.channel_multiplier_);
-    GPUOperation::operator=(std::move(operation));
-  }
-  return *this;
-}
-
-std::string DepthwiseConvolution::GenerateDepthwiseConvolutionCode(
-    const OperationDef& op_def, bool stride_correction, int channel_multiplier,
-    bool weights_are_buffer) {
+std::string GenerateDepthwiseConvolutionCode(const OperationDef& op_def,
+                                             bool stride_correction,
+                                             int channel_multiplier,
+                                             bool weights_are_buffer,
+                                             GPUOperation* op) {
   auto src_desc = op_def.src_tensors[0];
   src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
   if (op_def.IsBatchSupported()) {
     src_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddSrcTensor("src_tensor", src_desc);
+  op->AddSrcTensor("src_tensor", src_desc);
 
   auto dst_desc = op_def.dst_tensors[0];
   if (op_def.IsBatchSupported()) {
     dst_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddDstTensor("dst_tensor", dst_desc);
-
-  args_.AddInt("kernel_size_x");
-  args_.AddInt("stride_x");
-  args_.AddInt("padding_x");
-  args_.AddInt("dilation_x");
-  args_.AddInt("kernel_size_y");
-  args_.AddInt("stride_y");
-  args_.AddInt("padding_y");
-  args_.AddInt("dilation_y");
-  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    args_.AddInt("kernel_size_z");
-    args_.AddInt("stride_z");
-    args_.AddInt("padding_z");
-    args_.AddInt("dilation_z");
-  }
-  if (!IsSpecializedCase(channel_multiplier)) {
-    args_.AddInt("ch_multiplier");
-  }
+  op->AddDstTensor("dst_tensor", dst_desc);
 
   const auto src_tensor_type = op_def.src_tensors[0].storage_type;
 
@@ -171,14 +95,14 @@
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int X = get_global_id(0);\n";
-  c += "  int Y = get_global_id(1);\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    c += "  int linear_id_2 = get_global_id(2);\n";
-    c += "  int S = linear_id_2 / args.dst_tensor.Depth();\n";
-    c += "  int Z = linear_id_2 % args.dst_tensor.Depth();\n";
+    c += "  int linear_id_1 = get_global_id(1);\n";
+    c += "  int Y = linear_id_1 / args.dst_tensor.Depth();\n";
+    c += "  int Z = linear_id_1 % args.dst_tensor.Depth();\n";
   } else {
-    c += "  int S = get_global_id(2);\n";
+    c += "  int Y = get_global_id(1);\n";
   }
+  c += "  int S = get_global_id(2);\n";
   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
        "S >= args.dst_tensor.Slices()) { \n";
   c += "    return; \n";
@@ -186,11 +110,16 @@
   c += "  ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
   if (stride_correction) {
     c += "  int x_offseted = " +
-         GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x",
-                             "args.padding_x") +
+         GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x",
+                               "args.padding_x") +
          ";\n";
   } else {
-    c += "  int x_offseted = X * args.stride_x + args.padding_x;\n";
+    if (op_def.IsBatchSupported()) {
+      c += "  int x_offseted = X * args.stride_x + args.padding_x * "
+           "args.src_tensor.Batch();\n";
+    } else {
+      c += "  int x_offseted = X * args.stride_x + args.padding_x;\n";
+    }
   }
   c += "  int y_offseted = Y * args.stride_y + args.padding_y;\n";
   std::string weights_offset = "args.kernel_size_x * args.kernel_size_y";
@@ -218,7 +147,10 @@
     c += "    int y_c = y_offseted + ky * args.dilation_y;\n";
     c += "    bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
     c += "    for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
-    c += "      int x_c = x_offseted + kx * args.dilation_x;\n";
+    const std::string dilation_x =
+        op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
+                                  : "args.dilation_x";
+    c += "      int x_c = x_offseted + kx * " + dilation_x + ";\n";
     c += "      bool outside_x = x_c < 0 || x_c >= args.src_tensor.Width();\n";
     c += "      if (" + check + ") {\n";
     if (weights_are_buffer) {
@@ -252,7 +184,10 @@
     c += "  for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
     c += "    int y_c = y_offseted + ky * args.dilation_y;\n";
     c += "    for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
-    c += "      int x_c = x_offseted + kx * args.dilation_x;\n";
+    const std::string dilation_x =
+        op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
+                                  : "args.dilation_x";
+    c += "      int x_c = x_offseted + kx * " + dilation_x + ";\n";
     c += GetSrcValue(channel_multiplier, flat_coords);
     if (weights_are_buffer) {
       c += "      FLT4 f = args.weights.Read(fx_c);\n";
@@ -277,67 +212,80 @@
 
   return c;
 }
+}  // namespace
 
-absl::Status DepthwiseConvolution::BindArguments() {
-  RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
-  RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
-  RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch()));
-  RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch()));
-  RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y));
-  RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
-  RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
-  RETURN_IF_ERROR(args_.SetInt("dilation_y", dilation_.y));
-  if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z));
-    RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z));
-    RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z));
-    RETURN_IF_ERROR(args_.SetInt("dilation_z", dilation_.z));
-  }
-  if (!IsSpecializedCase(channel_multiplier_)) {
-    RETURN_IF_ERROR(args_.SetInt("ch_multiplier", channel_multiplier_));
-  }
-  return absl::OkStatus();
-}
-
-int3 DepthwiseConvolution::GetGridSize() const {
-  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
-  const int grid_y = dst_[0]->Height();
-  const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
-  return int3(grid_x, grid_y, grid_z);
-}
-
-DepthwiseConvolution CreateDepthwiseConvolution(
+GPUOperation CreateDepthwiseConvolution2D(
     const DeviceInfo& device_info, const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr) {
   bool weights_are_buffer = device_info.IsMali();
-  DepthwiseConvolution result(definition, attr, weights_are_buffer);
-  result.UploadWeights(attr.weights);
+  GPUOperation op(definition);
+  op.args_.AddInt("kernel_size_x", attr.weights.shape.w);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+  op.args_.AddInt("dilation_x", attr.dilations.w);
+  op.args_.AddInt("kernel_size_y", attr.weights.shape.h);
+  op.args_.AddInt("stride_y", attr.strides.h);
+  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+  op.args_.AddInt("dilation_y", attr.dilations.h);
+  if (!IsSpecializedCase(attr.weights.shape.o)) {
+    op.args_.AddInt("ch_multiplier", attr.weights.shape.o);
+  }
+  const bool stride_correction =
+      definition.IsBatchSupported() && attr.strides.w != 1;
+  op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction,
+                                              attr.weights.shape.o,
+                                              weights_are_buffer, &op);
+  UploadWeightsForDWConv2D(attr.weights, weights_are_buffer,
+                           definition.precision, &op);
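+  // The generated code now folds Depth into Y, so the generic
+  // kWBToX_HDToY_SToZ mapping replaces the removed GetGridSize() override.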
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
 
   TensorLinearDescriptor desc;
   desc.storage_type = weights_are_buffer ? LinearStorageType::BUFFER
                                          : LinearStorageType::TEXTURE_2D;
   desc.element_type = definition.GetDataType();
   desc.UploadLinearData(attr.bias);
-  result.args_.AddObject(
+  op.args_.AddObject(
       "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
-  return result;
+  return op;
 }
 
-DepthwiseConvolution CreateDepthwiseConvolution(
+GPUOperation CreateDepthwiseConvolution3D(
     const DeviceInfo& device_info, const OperationDef& definition,
     const DepthwiseConvolution3DAttributes& attr) {
   bool weights_are_buffer = device_info.IsMali();
-  DepthwiseConvolution result(definition, attr, weights_are_buffer);
-  result.UploadWeights(attr.weights);
+  GPUOperation op(definition);
+  op.args_.AddInt("kernel_size_x", attr.weights.shape.w);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+  op.args_.AddInt("dilation_x", attr.dilations.w);
+  op.args_.AddInt("kernel_size_y", attr.weights.shape.h);
+  op.args_.AddInt("stride_y", attr.strides.h);
+  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+  op.args_.AddInt("dilation_y", attr.dilations.h);
+  op.args_.AddInt("kernel_size_z", attr.weights.shape.d);
+  op.args_.AddInt("stride_z", attr.strides.d);
+  op.args_.AddInt("padding_z", -attr.padding.prepended.d);
+  op.args_.AddInt("dilation_z", attr.dilations.d);
+  if (!IsSpecializedCase(attr.weights.shape.o)) {
+    op.args_.AddInt("ch_multiplier", attr.weights.shape.o);
+  }
+  const bool stride_correction =
+      definition.IsBatchSupported() && attr.strides.w != 1;
+  op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction,
+                                              attr.weights.shape.o,
+                                              weights_are_buffer, &op);
+  UploadWeightsForDWConv3D(attr.weights, weights_are_buffer,
+                           definition.precision, &op);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
 
   TensorLinearDescriptor desc;
   desc.storage_type = weights_are_buffer ? LinearStorageType::BUFFER
                                          : LinearStorageType::TEXTURE_2D;
   desc.element_type = definition.GetDataType();
   desc.UploadLinearData(attr.bias);
-  result.args_.AddObject(
+  op.args_.AddObject(
       "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
-  return result;
+  return op;
 }
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h
index afa6375..a5a4f5b 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h
@@ -35,102 +35,9 @@
 namespace gpu {
 namespace cl {
 
-class DepthwiseConvolution : public GPUOperation {
- public:
-  DepthwiseConvolution() = default;
-  absl::Status BindArguments() override;
-  int3 GetGridSize() const override;
-
-  // Move only
-  DepthwiseConvolution(DepthwiseConvolution&& operation);
-  DepthwiseConvolution& operator=(DepthwiseConvolution&& operation);
-  DepthwiseConvolution(const DepthwiseConvolution&) = delete;
-  DepthwiseConvolution& operator=(const DepthwiseConvolution&) = delete;
-
- private:
-  friend DepthwiseConvolution CreateDepthwiseConvolution(
-      const DeviceInfo& device_info, const OperationDef& definition,
-      const DepthwiseConvolution2DAttributes& attr);
-  friend DepthwiseConvolution CreateDepthwiseConvolution(
-      const DeviceInfo& device_info, const OperationDef& definition,
-      const DepthwiseConvolution3DAttributes& attr);
-  DepthwiseConvolution(const OperationDef& definition,
-                       const DepthwiseConvolution2DAttributes& attr,
-                       bool weights_are_buffer);
-  DepthwiseConvolution(const OperationDef& definition,
-                       const DepthwiseConvolution3DAttributes& attr,
-                       bool weights_are_buffer);
-
-  template <DataType T>
-  void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights);
-
-  template <DataType S, typename T>
-  void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
-                            absl::Span<T> dst);
-
-  template <DataType T>
-  void UploadWeights(const tflite::gpu::Tensor<OHWDI, T>& weights);
-
-  template <DataType S, typename T>
-  void RearrangeWeightsData(const tflite::gpu::Tensor<OHWDI, S>& weights,
-                            absl::Span<T> dst);
-
-  std::string GenerateDepthwiseConvolutionCode(const OperationDef& op_def,
-                                               bool stride_correction,
-                                               int channel_multiplier,
-                                               bool weights_are_buffer);
-
-  bool weights_are_buffer_;
-
-  int4 kernel_size_;
-  int4 stride_;
-  int4 padding_;
-  int4 dilation_;
-  int channel_multiplier_;
-};
-
-template <DataType T>
-void DepthwiseConvolution::UploadWeights(
-    const tflite::gpu::Tensor<OHWI, T>& weights) {
-  const int dst_channels = weights.shape.i * weights.shape.o;
-  const int dst_slices = DivideRoundUp(dst_channels, 4);
-  const int kernel_x = weights.shape.w;
-  const int kernel_y = weights.shape.h;
-
-  const int elements_count = kernel_x * kernel_y * dst_slices;
-
-  const bool fp32_weights = definition_.precision == CalculationsPrecision::F32;
-  const int float4_size = fp32_weights ? 16 : 8;
-
-  std::vector<uint8_t> data(float4_size * elements_count);
-
-  if (fp32_weights) {
-    float4* ptr = reinterpret_cast<float4*>(data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count));
-  } else {
-    half4* ptr = reinterpret_cast<half4*>(data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count));
-  }
-
-  if (weights_are_buffer_) {
-    BufferDescriptor desc;
-    desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
-    desc.element_size = 4;
-    desc.size = float4_size * elements_count;
-    desc.data = std::move(data);
-    args_.AddObject("weights", absl::make_unique<BufferDescriptor>(desc));
-  } else {
-    Texture2DDescriptor desc;
-    desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
-    desc.size = int2(kernel_x * kernel_y, dst_slices);
-    desc.data = std::move(data);
-    args_.AddObject("weights", absl::make_unique<Texture2DDescriptor>(desc));
-  }
-}
-
 template <DataType S, typename T>
-void DepthwiseConvolution::RearrangeWeightsData(
-    const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
+void RearrangeWeightsForDWConv2D(const tflite::gpu::Tensor<OHWI, S>& weights,
+                                 absl::Span<T> dst) {
   const int dst_channels = weights.shape.i * weights.shape.o;
   const int dst_depth = DivideRoundUp(dst_channels, 4);
   const int kernel_x = weights.shape.w;
@@ -158,50 +65,50 @@
 }
 
 template <DataType T>
-void DepthwiseConvolution::UploadWeights(
-    const tflite::gpu::Tensor<OHWDI, T>& weights) {
+void UploadWeightsForDWConv2D(const tflite::gpu::Tensor<OHWI, T>& weights,
+                              bool weights_are_buffer,
+                              CalculationsPrecision precision,
+                              GPUOperation* op) {
   const int dst_channels = weights.shape.i * weights.shape.o;
   const int dst_slices = DivideRoundUp(dst_channels, 4);
   const int kernel_x = weights.shape.w;
   const int kernel_y = weights.shape.h;
-  const int kernel_z = weights.shape.d;
 
-  const int elements_count = kernel_x * kernel_y * kernel_z * dst_slices;
+  const int elements_count = kernel_x * kernel_y * dst_slices;
 
-  const bool fp32_weights = definition_.precision == CalculationsPrecision::F32;
+  const bool fp32_weights = precision == CalculationsPrecision::F32;
   const int float4_size = fp32_weights ? 16 : 8;
 
   std::vector<uint8_t> data(float4_size * elements_count);
 
   if (fp32_weights) {
     float4* ptr = reinterpret_cast<float4*>(data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count));
+    RearrangeWeightsForDWConv2D(weights, absl::MakeSpan(ptr, elements_count));
   } else {
     half4* ptr = reinterpret_cast<half4*>(data.data());
-    RearrangeWeightsData(weights, absl::MakeSpan(ptr, elements_count));
+    RearrangeWeightsForDWConv2D(weights, absl::MakeSpan(ptr, elements_count));
   }
 
-  if (weights_are_buffer_) {
+  if (weights_are_buffer) {
     BufferDescriptor desc;
     desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
     desc.element_size = 4;
     desc.size = float4_size * elements_count;
     desc.data = std::move(data);
-    args_.AddObject("weights",
-                    absl::make_unique<BufferDescriptor>(std::move(desc)));
+    op->args_.AddObject("weights", absl::make_unique<BufferDescriptor>(desc));
   } else {
     Texture2DDescriptor desc;
     desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
-    desc.size = int2(kernel_x * kernel_y * kernel_z, dst_slices);
+    desc.size = int2(kernel_x * kernel_y, dst_slices);
     desc.data = std::move(data);
-    args_.AddObject("weights",
-                    absl::make_unique<Texture2DDescriptor>(std::move(desc)));
+    op->args_.AddObject("weights",
+                        absl::make_unique<Texture2DDescriptor>(desc));
   }
 }
 
 template <DataType S, typename T>
-void DepthwiseConvolution::RearrangeWeightsData(
-    const tflite::gpu::Tensor<OHWDI, S>& weights, absl::Span<T> dst) {
+void RearrangeWeightsForDWConv3D(const tflite::gpu::Tensor<OHWDI, S>& weights,
+                                 absl::Span<T> dst) {
   const int dst_channels = weights.shape.i * weights.shape.o;
   const int dst_slices = DivideRoundUp(dst_channels, 4);
   const int kernel_x = weights.shape.w;
@@ -231,11 +138,55 @@
   }
 }
 
-DepthwiseConvolution CreateDepthwiseConvolution(
+template <DataType T>
+void UploadWeightsForDWConv3D(const tflite::gpu::Tensor<OHWDI, T>& weights,
+                              bool weights_are_buffer,
+                              CalculationsPrecision precision,
+                              GPUOperation* op) {
+  const int dst_channels = weights.shape.i * weights.shape.o;
+  const int dst_slices = DivideRoundUp(dst_channels, 4);
+  const int kernel_x = weights.shape.w;
+  const int kernel_y = weights.shape.h;
+  const int kernel_z = weights.shape.d;
+
+  const int elements_count = kernel_x * kernel_y * kernel_z * dst_slices;
+
+  const bool fp32_weights = precision == CalculationsPrecision::F32;
+  const int float4_size = fp32_weights ? 16 : 8;
+
+  std::vector<uint8_t> data(float4_size * elements_count);
+
+  if (fp32_weights) {
+    float4* ptr = reinterpret_cast<float4*>(data.data());
+    RearrangeWeightsForDWConv3D(weights, absl::MakeSpan(ptr, elements_count));
+  } else {
+    half4* ptr = reinterpret_cast<half4*>(data.data());
+    RearrangeWeightsForDWConv3D(weights, absl::MakeSpan(ptr, elements_count));
+  }
+
+  if (weights_are_buffer) {
+    BufferDescriptor desc;
+    desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+    desc.element_size = 4;
+    desc.size = float4_size * elements_count;
+    desc.data = std::move(data);
+    op->args_.AddObject("weights",
+                        absl::make_unique<BufferDescriptor>(std::move(desc)));
+  } else {
+    Texture2DDescriptor desc;
+    desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+    desc.size = int2(kernel_x * kernel_y * kernel_z, dst_slices);
+    desc.data = std::move(data);
+    op->args_.AddObject(
+        "weights", absl::make_unique<Texture2DDescriptor>(std::move(desc)));
+  }
+}
+
+GPUOperation CreateDepthwiseConvolution2D(
     const DeviceInfo& device_info, const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr);
 
-DepthwiseConvolution CreateDepthwiseConvolution(
+GPUOperation CreateDepthwiseConvolution3D(
     const DeviceInfo& device_info, const OperationDef& definition,
     const DepthwiseConvolution3DAttributes& attr);
 
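Note on the weight-upload helpers above: they pack the o*i depthwise channels four at a time, one FLT4 per slice, so the staging-buffer size is a pure function of the kernel footprint, the rounded-up slice count, and the precision. A minimal host-side sketch of that size arithmetic (DivideRoundUp is re-declared here so the snippet is self-contained; it mirrors the helper used in the diff):

#include <cstddef>

// Mirrors the DivideRoundUp helper used above (re-declared for self-containment).
inline int DivideRoundUp(int n, int d) { return (n + d - 1) / d; }

// Staging-buffer size in bytes for 2D depthwise weights, following the
// layout of UploadWeightsForDWConv2D: one float4 (16 B) or half4 (8 B)
// per (kernel_x, kernel_y, slice) cell.
inline std::size_t DWConv2DWeightsBytes(int o, int i, int kernel_x,
                                        int kernel_y, bool fp32_weights) {
  const int dst_slices = DivideRoundUp(o * i, 4);
  const int elements_count = kernel_x * kernel_y * dst_slices;
  const int float4_size = fp32_weights ? 16 : 8;
  return static_cast<std::size_t>(float4_size) * elements_count;
}

// Example: a 3x3 kernel over o*i = 10 channels in F16 needs
// DivideRoundUp(10, 4) = 3 slices -> 3 * 3 * 3 = 27 half4s = 216 bytes.
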
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc
index 5c3e596..eb43c0c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv_test.cc
@@ -55,7 +55,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      DepthwiseConvolution operation = CreateDepthwiseConvolution(
+      GPUOperation operation = CreateDepthwiseConvolution2D(
           creation_context_.GetDeviceInfo(), op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 2, 2, 2), &dst_tensor));
@@ -90,7 +90,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      DepthwiseConvolution operation = CreateDepthwiseConvolution(
+      GPUOperation operation = CreateDepthwiseConvolution2D(
           creation_context_.GetDeviceInfo(), op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 2, 2, 2), &dst_tensor));
@@ -126,7 +126,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      DepthwiseConvolution operation = CreateDepthwiseConvolution(
+      GPUOperation operation = CreateDepthwiseConvolution2D(
           creation_context_.GetDeviceInfo(), op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 2, 2, 4), &dst_tensor));
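Call sites migrate mechanically, as the test changes above show: the overload set dispatched on the attributes type becomes two explicitly named factories, and the concrete DepthwiseConvolution wrapper type disappears. A sketch, assuming attr2d/attr3d are the 2D and 3D attribute structs and device_info/op_def are in scope as in the tests:

// Before: DepthwiseConvolution op =
//     CreateDepthwiseConvolution(device_info, op_def, attr2d);
// After: the dimensionality is explicit and the result is a generic operation.
GPUOperation op2d = CreateDepthwiseConvolution2D(device_info, op_def, attr2d);
GPUOperation op3d = CreateDepthwiseConvolution3D(device_info, op_def, attr3d);
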
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
index 22a76c3..3203ec3 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/elementwise.cc
@@ -42,10 +42,10 @@
       result = "\n";
       break;
     case OperationType::ELU:
-      result = "$0.x = $0.x < (FLT)(0.0f) ? exp($0.x) - (FLT)(1.0f) : $0.x;\n";
-      result += "$0.y = $0.y < (FLT)(0.0f) ? exp($0.y) - (FLT)(1.0f) : $0.y;\n";
-      result += "$0.z = $0.z < (FLT)(0.0f) ? exp($0.z) - (FLT)(1.0f) : $0.z;\n";
-      result += "$0.w = $0.w < (FLT)(0.0f) ? exp($0.w) - (FLT)(1.0f) : $0.w;\n";
+      result = "$0.x = $0.x < (FLT)(0.0f) ? expm1($0.x) : $0.x;\n";
+      result += "$0.y = $0.y < (FLT)(0.0f) ? expm1($0.y) : $0.y;\n";
+      result += "$0.z = $0.z < (FLT)(0.0f) ? expm1($0.z) : $0.z;\n";
+      result += "$0.w = $0.w < (FLT)(0.0f) ? expm1($0.w) : $0.w;\n";
       break;
     case OperationType::EXP:
       result = "$0 = exp($0);\n";
@@ -59,7 +59,7 @@
       result = "$0 = log($0);\n";
       break;
     case OperationType::RSQRT:
-      result = "$0 = (FLT4)(1.0f) / sqrt($0);\n";
+      result = "$0 = rsqrt($0);\n";
       break;
     case OperationType::SIGMOID:
       if (precision != CalculationsPrecision::F32) {
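On the elementwise changes above: expm1(x) evaluates exp(x) - 1 without the cancellation the literal expression suffers for small |x|, and rsqrt typically maps to a single hardware reciprocal-square-root instruction instead of a division plus a sqrt; both are standard OpenCL built-ins. A small host-side illustration of the cancellation, using only <cmath> (the value of x is illustrative):

#include <cmath>
#include <cstdio>

int main() {
  const float x = 1e-8f;
  // exp(1e-8f) rounds to exactly 1.0f (the increment is below half an ULP
  // of 1.0f), so the literal form cancels to 0.
  const float naive = std::exp(x) - 1.0f;
  // expm1 keeps the low-order bits and returns ~1e-8.
  const float fused = std::expm1(x);
  std::printf("naive=%g expm1=%g\n", naive, fused);  // naive=0 expm1=1e-08
  return 0;
}
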
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
index f9d6ec7..b34b8e3 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc
@@ -213,7 +213,6 @@
     RETURN_IF_ERROR(args_.TransformToCLCode(
         creation_context.device->info_,
         {{dst_tensors_names_[0], elementwise_code_}}, &code));
-    code = absl::Substitute(code, args_.GetListOfArgs());
     RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
         code, "main_function", *creation_context.context,
         *creation_context.device, &kernel_));
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
index 0fc5e49..c98ac36 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.cc
@@ -24,33 +24,14 @@
 namespace tflite {
 namespace gpu {
 namespace cl {
-
-LSTM::LSTM(const OperationDef& definition, const DeviceInfo& device_info)
-    : GPUOperation(definition) {
-  code_ = GetLSTMCode(definition_, device_info);
-}
-
-LSTM::LSTM(LSTM&& kernel) : GPUOperation(std::move(kernel)) {}
-
-LSTM& LSTM::operator=(LSTM&& kernel) {
-  if (this != &kernel) {
-    GPUOperation::operator=(std::move(kernel));
-  }
-  return *this;
-}
-
-std::string LSTM::GetLSTMCode(const OperationDef& op_def,
-                              const DeviceInfo& device_info) {
-  AddSrcTensor("intermediate", op_def.src_tensors[0]);
-  AddSrcTensor("prev_state", op_def.src_tensors[1]);
-  AddDstTensor("new_state", op_def.dst_tensors[0]);
-  AddDstTensor("activation", op_def.dst_tensors[1]);
-
+namespace {
+std::string GetLSTMCode(const OperationDef& op_def,
+                        const DeviceInfo& device_info) {
   std::string c = GetCommonDefines(op_def.precision);
   c += "__kernel void main_function(\n";
   c += "$0) {\n";
   c += "  int B = get_global_id(0);\n";
-  c += "  int Z = get_global_id(1);\n";
+  c += "  int Z = get_global_id(2);\n";
   c += "  if (Z >= args.activation.Slices() || B >= args.activation.Batch()) "
        "return;\n";
   c += "  FLT4 prev_st = args.prev_state.Read(0, 0, Z, B);\n";
@@ -105,15 +86,18 @@
   return c;
 }
 
-int3 LSTM::GetGridSize() const {
-  const int grid_x = dst_[0]->Batch();
-  const int grid_y = dst_[0]->Slices();
-  const int grid_z = 1;
-  return int3(grid_x, grid_y, grid_z);
-}
+}  // namespace
 
-LSTM CreateLSTM(const OperationDef& definition, const DeviceInfo& device_info) {
-  return LSTM(definition, device_info);
+GPUOperation CreateLSTM(const OperationDef& definition,
+                        const DeviceInfo& device_info) {
+  GPUOperation op(definition);
+  op.AddSrcTensor("intermediate", definition.src_tensors[0]);
+  op.AddSrcTensor("prev_state", definition.src_tensors[1]);
+  op.AddDstTensor("new_state", definition.dst_tensors[0]);
+  op.AddDstTensor("activation", definition.dst_tensors[1]);
+  op.code_ = GetLSTMCode(definition, device_info);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+  return op;
 }
 
 }  // namespace cl
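The LSTM factory above replaces the hand-written GetGridSize with tensor_to_grid_, which is why the kernel now takes Z from get_global_id(2) rather than get_global_id(1). A sketch of the grid that kWBToX_HDToY_SToZ is assumed to produce (the enum semantics here are an assumption inferred from its name and the removed code):

struct Int3 { int x, y, z; };  // stand-in for the codebase's int3

// Assumed meaning of TensorToGrid::kWBToX_HDToY_SToZ: width*batch -> X,
// height*depth -> Y, slices -> Z. For the LSTM state shape used here
// (W = H = D = 1) this yields {B, 1, S}, so work-item X enumerates batches
// and work-item Z enumerates slices, matching get_global_id(2) above.
Int3 GridForWBToX_HDToY_SToZ(int w, int h, int d, int s, int b) {
  return {w * b, h * d, s};
}
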
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h
index 91bfd22..5d827d4 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm.h
@@ -25,23 +25,8 @@
 namespace gpu {
 namespace cl {
 
-class LSTM : public GPUOperation {
- public:
-  LSTM(const OperationDef& definition, const DeviceInfo& device_info);
-  int3 GetGridSize() const override;
-
-  // Move only
-  LSTM(LSTM&& kernel);
-  LSTM& operator=(LSTM&& kernel);
-  LSTM(const LSTM&) = delete;
-  LSTM& operator=(const LSTM&) = delete;
-
- private:
-  std::string GetLSTMCode(const OperationDef& op_def,
-                          const DeviceInfo& device_info);
-};
-
-LSTM CreateLSTM(const OperationDef& definition, const DeviceInfo& device_info);
+GPUOperation CreateLSTM(const OperationDef& definition,
+                        const DeviceInfo& device_info);
 
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/lstm_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/lstm_test.cc
index 52e9b4c..8982d99 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/lstm_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/lstm_test.cc
@@ -67,7 +67,7 @@
       op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
       TensorFloat32 new_state;
       TensorFloat32 new_activ;
-      LSTM operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
+      GPUOperation operation = CreateLSTM(op_def, env_.GetDevicePtr()->info_);
       ASSERT_OK(ExecuteGPUOperation(
           {src_tensor, prev_state}, creation_context_, &operation,
           {BHWC(1, 1, 1, 4), BHWC(1, 1, 1, 4)}, {&new_state, &new_activ}));
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
index 97ee487..0bea5e4 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.cc
@@ -23,76 +23,26 @@
 namespace tflite {
 namespace gpu {
 namespace cl {
-
-MaxUnpooling::MaxUnpooling(const OperationDef& definition,
-                           const MaxUnpooling2DAttributes& attr)
-    : GPUOperation(definition),
-      stride_(attr.strides.w, attr.strides.h, 0, 0),
-      padding_(attr.padding.appended.w, attr.padding.appended.h, 0, 0),
-      kernel_size_(attr.kernel.w, attr.kernel.h, 0, 0) {
-  code_ = GetMaxUnpoolingKernelCode(definition_);
-}
-
-MaxUnpooling::MaxUnpooling(const OperationDef& definition,
-                           const MaxUnpooling3DAttributes& attr)
-    : GPUOperation(definition),
-      stride_(attr.strides.w, attr.strides.h, attr.strides.d, 0),
-      padding_(attr.padding.appended.w, attr.padding.appended.h,
-               attr.padding.appended.d, 0),
-      kernel_size_(attr.kernel.w, attr.kernel.h, attr.kernel.d, 0) {
-  code_ = GetMaxUnpoolingKernelCode(definition_);
-}
-
-MaxUnpooling::MaxUnpooling(MaxUnpooling&& kernel)
-    : GPUOperation(std::move(kernel)),
-      stride_(kernel.stride_),
-      padding_(kernel.padding_),
-      kernel_size_(kernel.kernel_size_) {}
-
-MaxUnpooling& MaxUnpooling::operator=(MaxUnpooling&& kernel) {
-  if (this != &kernel) {
-    std::swap(stride_, kernel.stride_);
-    std::swap(padding_, kernel.padding_);
-    std::swap(kernel_size_, kernel.kernel_size_);
-    GPUOperation::operator=(std::move(kernel));
-  }
-  return *this;
-}
-
-std::string MaxUnpooling::GetMaxUnpoolingKernelCode(
-    const OperationDef& op_def) {
+namespace {
+std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def,
+                                      GPUOperation* op) {
   auto src_desc = op_def.src_tensors[0];
   src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
   if (op_def.IsBatchSupported()) {
     src_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddSrcTensor("src_tensor", src_desc);
+  op->AddSrcTensor("src_tensor", src_desc);
   auto src_ind_desc = op_def.src_tensors[1];
   src_ind_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
   if (op_def.IsBatchSupported()) {
     src_ind_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddSrcTensor("src_indices", src_ind_desc);
+  op->AddSrcTensor("src_indices", src_ind_desc);
   auto dst_desc = op_def.dst_tensors[0];
   if (op_def.IsBatchSupported()) {
     dst_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddDstTensor("dst_tensor", dst_desc);
-  if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
-    args_.AddInt("kernel_size_x");
-    args_.AddInt("padding_x");
-    args_.AddInt("stride_x");
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
-    args_.AddInt("kernel_size_y");
-    args_.AddInt("padding_y");
-    args_.AddInt("stride_y");
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    args_.AddInt("kernel_size_z");
-    args_.AddInt("padding_z");
-    args_.AddInt("stride_z");
-  }
+  op->AddDstTensor("dst_tensor", dst_desc);
 
   std::string c = GetCommonDefines(op_def.precision);
   c += "__kernel void main_function(\n";
@@ -115,7 +65,8 @@
     c += "  int linear_id_0 = get_global_id(0);\n";
     c += "  int X0 = linear_id_0 / args.dst_tensor.Batch();\n";
     c += "  int B = linear_id_0 % args.dst_tensor.Batch();\n";
-    c += "  int src_x0 = (X0 + args.padding_x) / args.stride_x;\n";
+    c += "  int src_x0 = (X0 + args.padding_x * args.dst_tensor.Batch()) / "
+         "args.stride_x;\n";
     c += "  int src_x = src_x0 * args.dst_tensor.Batch() + B;\n";
   } else {
     c += "  int src_x = (X + args.padding_x) / args.stride_x;\n";
@@ -145,7 +96,8 @@
         "  int4 ind = convert_int4(args.src_indices.Read(" + src_args + "));\n";
   }
   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
-    c += "  int t_x = X0 - (src_x0 * args.stride_x - args.padding_x);\n";
+    c += "  int t_x = X0 - (src_x0 * args.stride_x - args.padding_x * "
+         "args.dst_tensor.Batch());\n";
   } else {
     c += "  int t_x = X - (src_x * args.stride_x - args.padding_x);\n";
   }
@@ -172,41 +124,37 @@
 
   return c;
 }
+}  // namespace
 
-absl::Status MaxUnpooling::BindArguments() {
-  if (definition_.dst_tensors[0].HasAxis(Axis::WIDTH)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
-    RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch()));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
-  }
-  if (definition_.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
-    RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y));
-  }
-  if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z));
-    RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z));
-  }
-  return absl::OkStatus();
-}
-
-int3 MaxUnpooling::GetGridSize() const {
-  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
-  const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
-  const int grid_z = dst_[0]->Slices();
-  return int3(grid_x, grid_y, grid_z);
-}
-
-MaxUnpooling CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                 const MaxUnpooling2DAttributes& attr) {
-  return MaxUnpooling(definition, attr);
+  GPUOperation op(definition);
+  op.args_.AddInt("kernel_size_x", attr.kernel.w);
+  op.args_.AddInt("padding_x", attr.padding.appended.w);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("kernel_size_y", attr.kernel.h);
+  op.args_.AddInt("padding_y", attr.padding.appended.h);
+  op.args_.AddInt("stride_y", attr.strides.h);
+  op.code_ = GetMaxUnpoolingKernelCode(definition, &op);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+  return op;
 }
 
-MaxUnpooling CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                 const MaxUnpooling3DAttributes& attr) {
-  return MaxUnpooling(definition, attr);
+  GPUOperation op(definition);
+  op.args_.AddInt("kernel_size_x", attr.kernel.w);
+  op.args_.AddInt("padding_x", attr.padding.appended.w);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("kernel_size_y", attr.kernel.h);
+  op.args_.AddInt("padding_y", attr.padding.appended.h);
+  op.args_.AddInt("stride_y", attr.strides.h);
+  op.args_.AddInt("kernel_size_z", attr.kernel.d);
+  op.args_.AddInt("padding_z", attr.padding.appended.d);
+  op.args_.AddInt("stride_z", attr.strides.d);
+  op.code_ = GetMaxUnpoolingKernelCode(definition, &op);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+  return op;
 }
 
 }  // namespace cl
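With batched tensors the X coordinate is pre-multiplied by Batch(), so padding_x must be scaled by the same factor before it enters the index math. Previously BindArguments did that scaling at launch (padding_.x * src_[0]->Batch()); now the factor appears inline in the generated code while the factories store the raw attribute value. A host-side mirror of the batched index math, with assumed example values:

// Host-side mirror of the batched unpooling X math in the kernel above.
struct UnpoolX {
  int src_x0;  // un-batched source column
  int t_x;     // offset of X0 inside the pooling window
};

UnpoolX BatchedUnpoolX(int X0, int padding_x, int stride_x, int batch) {
  const int src_x0 = (X0 + padding_x * batch) / stride_x;
  const int t_x = X0 - (src_x0 * stride_x - padding_x * batch);
  return {src_x0, t_x};
}

// E.g. X0 = 3, padding_x = 1, stride_x = 2, batch = 2
//   -> src_x0 = (3 + 2) / 2 = 2, t_x = 3 - (4 - 2) = 1.
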
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
index 0b1420a..c1b6cbf 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h
@@ -25,34 +25,10 @@
 namespace gpu {
 namespace cl {
 
-class MaxUnpooling : public GPUOperation {
- public:
-  MaxUnpooling(const OperationDef& definition,
-               const MaxUnpooling2DAttributes& attr);
-  MaxUnpooling(const OperationDef& definition,
-               const MaxUnpooling3DAttributes& attr);
-
-  absl::Status BindArguments() override;
-  int3 GetGridSize() const override;
-
-  // Move only
-  MaxUnpooling(MaxUnpooling&& kernel);
-  MaxUnpooling& operator=(MaxUnpooling&& kernel);
-  MaxUnpooling(const MaxUnpooling&) = delete;
-  MaxUnpooling& operator=(const MaxUnpooling&) = delete;
-
- private:
-  std::string GetMaxUnpoolingKernelCode(const OperationDef& op_def);
-
-  int4 stride_;
-  int4 padding_;
-  int4 kernel_size_;
-};
-
-MaxUnpooling CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                 const MaxUnpooling2DAttributes& attr);
 
-MaxUnpooling CreateMaxUnpooling(const OperationDef& definition,
+GPUOperation CreateMaxUnpooling(const OperationDef& definition,
                                 const MaxUnpooling3DAttributes& attr);
 
 }  // namespace cl
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling_test.cc
index c03cb4f..654b389 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling_test.cc
@@ -55,7 +55,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      MaxUnpooling operation = CreateMaxUnpooling(op_def, attr);
+      GPUOperation operation = CreateMaxUnpooling(op_def, attr);
       ASSERT_OK(ExecuteGPUOperation({src_tensor, src_ind_tensor},
                                     creation_context_, &operation,
                                     BHWC(1, 4, 4, 1), &dst_tensor));
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc
index a89d712..c36dacd 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.cc
@@ -50,13 +50,13 @@
 static inline float local_reduce(float input, __local float* tmp) {
   const int local_id = get_local_id(0);
   tmp[local_id] = input;
-  mem_fence(CLK_LOCAL_MEM_FENCE);
+  barrier(CLK_LOCAL_MEM_FENCE);
   int reduction_size = get_local_size(0) / 2;
   while (reduction_size > 0) {
     if (local_id < reduction_size) {
       tmp[local_id] += tmp[local_id + reduction_size];
     }
-    mem_fence(CLK_LOCAL_MEM_FENCE);
+    barrier(CLK_LOCAL_MEM_FENCE);
     reduction_size /= 2;
   }
   return tmp[0];
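This is a genuine bug fix, not a cleanup: mem_fence only orders memory operations within the calling work-item, so neighboring work-items could read stale tmp values between reduction levels, while barrier(CLK_LOCAL_MEM_FENCE) actually synchronizes the whole work-group at each step. The corrected tree reduction, written out as plain OpenCL C inside a C++ string (a sketch of the pattern this file builds):

const char* kLocalReduceSketch = R"(
static inline float local_reduce(float input, __local float* tmp) {
  const int local_id = get_local_id(0);
  tmp[local_id] = input;
  barrier(CLK_LOCAL_MEM_FENCE);  // every store to tmp visible to all items
  for (int size = get_local_size(0) / 2; size > 0; size /= 2) {
    if (local_id < size) tmp[local_id] += tmp[local_id + size];
    barrier(CLK_LOCAL_MEM_FENCE);  // synchronize before the next level
  }
  return tmp[0];
}
)";
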
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization_test.cc
index 57f0525..8ff34be 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization_test.cc
@@ -54,7 +54,8 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
       TensorFloat32 dst_tensor;
-      auto operation = CreateMeanStdDevNormalization(op_def);
+      auto operation =
+          CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_);
       ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation,
                                     BHWC(1, 1, 1, 4), &dst_tensor));
 
@@ -72,18 +73,19 @@
   }
 }
 
+// Note: 100.01 is not exactly representable in FP16 (it is in FP32), so
+// 101.0 is used instead.
 INSTANTIATE_TEST_SUITE_P(
     uKernels, MeanStddevNormalizationTest,
     testing::Values(
         std::make_tuple(0.0f, 0.0f, 0.0f),         // zero mean, zero variance
-        std::make_tuple(0.0f, 0.01f, 2.53e-5f),    // zero mean, small variance
-        std::make_tuple(0.0f, 100.0f, 1.20e-7f),   // zero mean, large variance
+        std::make_tuple(0.0f, 0.01f, 2.63e-4f),    // zero mean, small variance
+        std::make_tuple(0.0f, 100.0f, 2.63e-4f),   // zero mean, large variance
         std::make_tuple(0.01f, 0.0f, 0.0f),        // small mean, zero variance
-        std::make_tuple(0.01f, 0.01f, 2.53e-5f),   // small mean, small variance
-        std::make_tuple(0.01f, 100.0f, 1.20e-7f),  // small mean, large variance
+        std::make_tuple(0.01f, 0.01f, 3.57e-4f),   // small mean, small variance
+        std::make_tuple(1.0f, 100.0f, 2.63e-4f),   // small mean, large variance
         std::make_tuple(100.0f, 0.0f, 0.0f),       // large mean, zero variance
-        std::make_tuple(100.0f, 0.01f, 1.81e-4f),  // large mean, small variance
-        std::make_tuple(100.0f, 100.0f, 1.20e-7f)  // large mean, large variance
+        std::make_tuple(100.0f, 1.0f, 2.63e-4f),   // large mean, small variance
+        std::make_tuple(100.0f, 100.0f, 2.63e-4f)  // large mean, large variance
         ));
 
 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(MeanStddevNormalizationTest);
@@ -92,15 +94,15 @@
   TensorFloat32 src_tensor;
   src_tensor.shape = BHWC(9, 1, 1, 4);
   src_tensor.data = {
-      0.0f,     0.0f,    0.0f,    0.0f,     // zero mean, zero variance
-      -0.02f,   -0.01f,  0.01f,   0.02f,    // zero mean, small variance
-      -200.0f,  -100.0f, 100.0f,  200.0f,   // zero mean, large variance
-      0.01f,    0.01f,   0.01f,   0.01f,    // small mean, zero variance
-      -0.01f,   0.0f,    0.02f,   0.03f,    // small mean, small variance
-      -199.99f, -99.99f, 100.01f, 200.01f,  // small mean, large variance
-      100.0f,   100.0f,  100.0f,  100.0f,   // large mean, zero variance
-      99.98f,   99.99f,  100.01f, 100.02f,  // large mean, small variance
-      -100.0f,  0.0f,    200.0f,  300.0f,   // large mean, large variance
+      0.0f,    0.0f,    0.0f,   0.0f,    // zero mean, zero variance
+      -0.02f,  -0.01f,  0.01f,  0.02f,   // zero mean, small variance
+      -200.0f, -100.0f, 100.0f, 200.0f,  // zero mean, large variance
+      0.01f,   0.01f,   0.01f,  0.01f,   // small mean, zero variance
+      -0.01f,  0.0f,    0.02f,  0.03f,   // small mean, small variance
+      -199.0f, -99.0f,  101.0f, 201.0f,  // small mean, large variance
+      100.0f,  100.0f,  100.0f, 100.0f,  // large mean, zero variance
+      98.0f,   99.0f,   101.0f, 102.0f,  // large mean, small variance
+      -100.0f, 0.0f,    200.0f, 300.0f,  // large mean, large variance
   };
   for (auto storage : env_.GetSupportedStorages()) {
     for (auto precision : env_.GetSupportedPrecisions()) {
@@ -110,7 +112,8 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
       TensorFloat32 dst_tensor;
-      auto operation = CreateMeanStdDevNormalization(op_def);
+      auto operation =
+          CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_);
       ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation,
                                     BHWC(9, 1, 1, 4), &dst_tensor));
 
@@ -128,7 +131,7 @@
           -ksqrt16, -ksqrt04, ksqrt04, ksqrt16,  // large mean, large variance
       };
       EXPECT_THAT(dst_tensor.data,
-                  Pointwise(FloatNear(1.81e-4f), expected_output));
+                  Pointwise(FloatNear(3.57e-4f), expected_output));
     }
   }
 }
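For reference, the expected outputs are simply (x - mean) / stddev per row. A self-contained check of the "zero mean, small variance" row, which reproduces the ksqrt16/ksqrt04 constants used in the test above:

#include <cmath>
#include <cstdio>

int main() {
  const float v[4] = {-0.02f, -0.01f, 0.01f, 0.02f};
  float mean = 0.0f, var = 0.0f;
  for (float x : v) mean += x / 4.0f;
  for (float x : v) var += (x - mean) * (x - mean) / 4.0f;
  const float stddev = std::sqrt(var);  // sqrt(2.5e-4) ~= 0.015811
  for (float x : v) std::printf("% .4f ", (x - mean) / stddev);
  // Prints -1.2649 -0.6325 0.6325 1.2649, i.e. -sqrt(1.6), -sqrt(0.4),
  // sqrt(0.4), sqrt(1.6): the ksqrt16/ksqrt04 values in the test.
  return 0;
}
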
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc
index fb077fe..af16461 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc
@@ -23,78 +23,21 @@
 namespace tflite {
 namespace gpu {
 namespace cl {
-
-Pooling::Pooling(const OperationDef& definition,
-                 const Pooling2DAttributes& attr)
-    : GPUOperation(definition),
-      stride_(attr.strides.w, attr.strides.h, 0, 0),
-      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
-      kernel_size_(attr.kernel.w, attr.kernel.h, 0, 0),
-      type_(attr.type),
-      output_indices_(attr.output_indices) {
-  GenerateCode();
-}
-
-Pooling::Pooling(const OperationDef& definition,
-                 const Pooling3DAttributes& attr)
-    : GPUOperation(definition),
-      stride_(attr.strides.w, attr.strides.h, attr.strides.d, 0),
-      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h,
-               -attr.padding.prepended.d, 0),
-      kernel_size_(attr.kernel.w, attr.kernel.h, attr.kernel.d, 0),
-      type_(attr.type),
-      output_indices_(attr.output_indices) {
-  GenerateCode();
-}
-
-Pooling::Pooling(Pooling&& kernel)
-    : GPUOperation(std::move(kernel)),
-      stride_(kernel.stride_),
-      padding_(kernel.padding_),
-      kernel_size_(kernel.kernel_size_),
-      type_(kernel.type_),
-      output_indices_(kernel.output_indices_) {}
-
-Pooling& Pooling::operator=(Pooling&& kernel) {
-  if (this != &kernel) {
-    std::swap(stride_, kernel.stride_);
-    std::swap(padding_, kernel.padding_);
-    std::swap(kernel_size_, kernel.kernel_size_);
-    std::swap(type_, kernel.type_);
-    std::swap(output_indices_, kernel.output_indices_);
-    GPUOperation::operator=(std::move(kernel));
-  }
-  return *this;
-}
-
-std::string Pooling::GetAveragePoolingKernelCode(const OperationDef& op_def,
-                                                 bool stride_correction) {
+namespace {
+std::string GetAveragePoolingKernelCode(const OperationDef& op_def,
+                                        bool stride_correction,
+                                        GPUOperation* op) {
   auto src_desc = op_def.src_tensors[0];
   src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
   if (op_def.IsBatchSupported()) {
     src_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddSrcTensor("src_tensor", src_desc);
+  op->AddSrcTensor("src_tensor", src_desc);
   auto dst_desc = op_def.dst_tensors[0];
   if (op_def.IsBatchSupported()) {
     dst_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddDstTensor("dst_tensor", dst_desc);
-  if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
-    args_.AddInt("kernel_size_x");
-    args_.AddInt("padding_x");
-    args_.AddInt("stride_x");
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
-    args_.AddInt("kernel_size_y");
-    args_.AddInt("padding_y");
-    args_.AddInt("stride_y");
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    args_.AddInt("kernel_size_z");
-    args_.AddInt("padding_z");
-    args_.AddInt("stride_z");
-  }
+  op->AddDstTensor("dst_tensor", dst_desc);
 
   std::map<Axis, std::string> axis_to_src_coord = {
       {Axis::WIDTH, "x_c"},  {Axis::HEIGHT, "y_c"}, {Axis::DEPTH, "d_c"},
@@ -149,11 +92,16 @@
   c += "  float window_size = 0.0;\n";
   if (stride_correction) {
     c += "  int xs = " +
-         GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x",
-                             "args.padding_x") +
+         GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x",
+                               "args.padding_x") +
          ";\n";
   } else {
-    c += "  int xs = X * args.stride_x + args.padding_x;\n";
+    if (op_def.IsBatchSupported()) {
+      c += "  int xs = X * args.stride_x + args.padding_x * "
+           "args.src_tensor.Batch();\n";
+    } else {
+      c += "  int xs = X * args.stride_x + args.padding_x;\n";
+    }
   }
   c += "  int ys = Y * args.stride_y + args.padding_y;\n";
   if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
@@ -195,40 +143,25 @@
   return c;
 }
 
-std::string Pooling::GetMaxPoolingKernelCode(const OperationDef& op_def,
-                                             bool stride_correction,
-                                             bool output_indices) {
+std::string GetMaxPoolingKernelCode(const OperationDef& op_def,
+                                    bool stride_correction, bool output_indices,
+                                    GPUOperation* op) {
   auto src_desc = op_def.src_tensors[0];
   if (op_def.IsBatchSupported()) {
     src_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddSrcTensor("src_tensor", src_desc);
+  op->AddSrcTensor("src_tensor", src_desc);
   auto dst_desc = op_def.dst_tensors[0];
   if (op_def.IsBatchSupported()) {
     dst_desc.SetStateVar("BatchedWidth", "true");
   }
-  AddDstTensor("dst_tensor", dst_desc);
+  op->AddDstTensor("dst_tensor", dst_desc);
   if (output_indices) {
     auto dst_ind_desc = op_def.dst_tensors[1];
     if (op_def.IsBatchSupported()) {
       dst_ind_desc.SetStateVar("BatchedWidth", "true");
     }
-    AddDstTensor("dst_indices", dst_ind_desc);
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
-    args_.AddInt("kernel_size_x");
-    args_.AddInt("padding_x");
-    args_.AddInt("stride_x");
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
-    args_.AddInt("kernel_size_y");
-    args_.AddInt("padding_y");
-    args_.AddInt("stride_y");
-  }
-  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    args_.AddInt("kernel_size_z");
-    args_.AddInt("padding_z");
-    args_.AddInt("stride_z");
+    op->AddDstTensor("dst_indices", dst_ind_desc);
   }
 
   std::map<Axis, std::string> axis_to_src_coord = {
@@ -282,11 +215,16 @@
   }
   if (stride_correction) {
     c += "  int xs = " +
-         GetXStrideCorrected("X", "args.src_tensor.Batch()", "args.stride_x",
-                             "args.padding_x") +
+         GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x",
+                               "args.padding_x") +
          ";\n";
   } else {
-    c += "  int xs = X * args.stride_x + args.padding_x;\n";
+    if (op_def.IsBatchSupported()) {
+      c += "  int xs = X * args.stride_x + args.padding_x * "
+           "args.src_tensor.Batch();\n";
+    } else {
+      c += "  int xs = X * args.stride_x + args.padding_x;\n";
+    }
   }
   c += "  int ys = Y * args.stride_y + args.padding_y;\n";
   c += "  for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
@@ -346,52 +284,51 @@
 
   return c;
 }
+}  // namespace
 
-void Pooling::GenerateCode() {
+GPUOperation CreatePooling(const OperationDef& definition,
+                           const Pooling2DAttributes& attr) {
+  GPUOperation op(definition);
+  op.args_.AddInt("kernel_size_x", attr.kernel.w);
+  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("kernel_size_y", attr.kernel.h);
+  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+  op.args_.AddInt("stride_y", attr.strides.h);
   const bool stride_correction =
-      definition_.IsBatchSupported() && stride_.x != 1;
-  if (type_ == PoolingType::AVERAGE) {
-    code_ = GetAveragePoolingKernelCode(definition_, stride_correction);
-  } else if (type_ == PoolingType::MAX) {
-    code_ = GetMaxPoolingKernelCode(definition_, stride_correction,
-                                    output_indices_);
+      definition.IsBatchSupported() && attr.strides.w != 1;
+  if (attr.type == PoolingType::AVERAGE) {
+    op.code_ = GetAveragePoolingKernelCode(definition, stride_correction, &op);
+  } else if (attr.type == PoolingType::MAX) {
+    op.code_ = GetMaxPoolingKernelCode(definition, stride_correction,
+                                       attr.output_indices, &op);
   }
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+  return op;
 }
 
-absl::Status Pooling::BindArguments() {
-  if (definition_.dst_tensors[0].HasAxis(Axis::WIDTH)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
-    RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch()));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
+GPUOperation CreatePooling(const OperationDef& definition,
+                           const Pooling3DAttributes& attr) {
+  GPUOperation op(definition);
+  op.args_.AddInt("kernel_size_x", attr.kernel.w);
+  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+  op.args_.AddInt("stride_x", attr.strides.w);
+  op.args_.AddInt("kernel_size_y", attr.kernel.h);
+  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+  op.args_.AddInt("stride_y", attr.strides.h);
+  op.args_.AddInt("kernel_size_z", attr.kernel.d);
+  op.args_.AddInt("padding_z", -attr.padding.prepended.d);
+  op.args_.AddInt("stride_z", attr.strides.d);
+  const bool stride_correction =
+      definition.IsBatchSupported() && attr.strides.w != 1;
+  if (attr.type == PoolingType::AVERAGE) {
+    op.code_ = GetAveragePoolingKernelCode(definition, stride_correction, &op);
+  } else if (attr.type == PoolingType::MAX) {
+    op.code_ = GetMaxPoolingKernelCode(definition, stride_correction,
+                                       attr.output_indices, &op);
   }
-  if (definition_.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
-    RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y));
-  }
-  if (definition_.dst_tensors[0].HasAxis(Axis::DEPTH)) {
-    RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z));
-    RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z));
-    RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z));
-  }
-  return absl::OkStatus();
-}
-
-int3 Pooling::GetGridSize() const {
-  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
-  const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
-  const int grid_z = dst_[0]->Slices();
-  return int3(grid_x, grid_y, grid_z);
-}
-
-Pooling CreatePooling(const OperationDef& definition,
-                      const Pooling2DAttributes& attr) {
-  return Pooling(definition, attr);
-}
-
-Pooling CreatePooling(const OperationDef& definition,
-                      const Pooling3DAttributes& attr) {
-  return Pooling(definition, attr);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+  return op;
 }
 
 }  // namespace cl
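The pooling attributes never change after construction, so the refactor folds them into the operation at creation time instead of declaring placeholders during code generation and binding values at launch. In sketch form (names follow the diff; error handling elided):

// Before: two-phase plumbing for every scalar argument.
//   args_.AddInt("stride_x");                              // GenerateCode()
//   RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));  // BindArguments()
// After: the value is supplied once, when the operation is created.
//   op.args_.AddInt("stride_x", attr.strides.w);
// The one launch-dependent quantity, the Batch() factor on padding_x, moved
// into the generated OpenCL ("args.padding_x * args.src_tensor.Batch()"),
// which is what allowed BindArguments() to be deleted entirely.
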
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h
index 18bb426..81a0dff 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.h
@@ -27,42 +27,11 @@
 namespace gpu {
 namespace cl {
 
-class Pooling : public GPUOperation {
- public:
-  Pooling(const OperationDef& definition, const Pooling2DAttributes& attr);
-  Pooling(const OperationDef& definition, const Pooling3DAttributes& attr);
+GPUOperation CreatePooling(const OperationDef& definition,
+                           const Pooling2DAttributes& attr);
 
-  absl::Status BindArguments() override;
-  int3 GetGridSize() const override;
-
-  // Move only
-  Pooling(Pooling&& kernel);
-  Pooling& operator=(Pooling&& kernel);
-  Pooling(const Pooling&) = delete;
-  Pooling& operator=(const Pooling&) = delete;
-
- private:
-  std::string GetAveragePoolingKernelCode(const OperationDef& op_def,
-                                          bool stride_correction);
-  std::string GetMaxPoolingKernelCode(const OperationDef& op_def,
-                                      bool stride_correction,
-                                      bool output_indices);
-
-  void GenerateCode();
-
-  int4 stride_;
-  int4 padding_;
-  int4 kernel_size_;
-
-  PoolingType type_;
-  bool output_indices_;
-};
-
-Pooling CreatePooling(const OperationDef& definition,
-                      const Pooling2DAttributes& attr);
-
-Pooling CreatePooling(const OperationDef& definition,
-                      const Pooling3DAttributes& attr);
+GPUOperation CreatePooling(const OperationDef& definition,
+                           const Pooling3DAttributes& attr);
 
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/pooling_test.cc
index 12efd56..af99b52 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling_test.cc
@@ -52,7 +52,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      Pooling operation = CreatePooling(op_def, attr);
+      GPUOperation operation = CreatePooling(op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 1, 1, 2), &dst_tensor));
       EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {3.0f, 4.0f}));
@@ -81,7 +81,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      Pooling operation = CreatePooling(op_def, attr);
+      GPUOperation operation = CreatePooling(op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 2, 2, 1), &dst_tensor));
       EXPECT_THAT(dst_tensor.data,
@@ -111,7 +111,7 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      Pooling operation = CreatePooling(op_def, attr);
+      GPUOperation operation = CreatePooling(op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 1, 1, 2), &dst_tensor));
       EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {8.0f, 7.0f}));
@@ -143,7 +143,7 @@
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
       TensorFloat32 dst_tensor_ind;
-      Pooling operation = CreatePooling(op_def, attr);
+      GPUOperation operation = CreatePooling(op_def, attr);
       ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation,
                                     {BHWC(1, 1, 1, 2), BHWC(1, 1, 1, 2)},
                                     {&dst_tensor, &dst_tensor_ind}));
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc
index 7a29d57..bcda1f6 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.cc
@@ -18,47 +18,75 @@
 #include "absl/strings/str_cat.h"
 #include "absl/types/variant.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
+#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
 namespace cl {
 
-absl::Status CreatePReLU(const CreationContext& creation_context,
+GPUOperation CreatePReLU(const DeviceInfo& device_info,
                          const OperationDef& definition,
-                         const PReLUAttributes& attr, GPUOperation* result) {
-  *result = GPUOperation(definition);
-  result->elementwise_ = true;
+                         const PReLUAttributes& attr) {
+  GPUOperation result(definition);
+  result.elementwise_ = true;
+
+  std::string alpha_read;
+  auto alpha_linear =
+      absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
+  if (alpha_linear) {
+    TensorLinearDescriptor desc;
+    desc.storage_type =
+        DeduceLinearStorageType(definition.GetPrimaryStorageType());
+    desc.element_type = definition.GetPrimaryDataType();
+    desc.UploadLinearData(*alpha_linear);
+    result.args_.AddObject(
+        "alpha", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+    alpha_read = "FLT4 alpha_val = args.alpha.Read(S_COORD);\n";
+  }
+
+  auto alpha_hwc =
+      absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
+  if (alpha_hwc) {
+    const BHWC shape =
+        BHWC(1, alpha_hwc->shape.h, alpha_hwc->shape.w, alpha_hwc->shape.c);
+    TensorStorageType storage_type = SelectBestStorageType(
+        device_info, shape, definition.GetPrimaryStorageType(),
+        definition.GetDataType(), Layout::HWC);
+    TensorDescriptor desc{definition.GetDataType(), storage_type, Layout::HWC};
+    desc.UploadData(*alpha_hwc);
+    result.args_.AddObject(
+        "alpha", absl::make_unique<TensorDescriptor>(std::move(desc)));
+    const std::string x_coord = shape.w == 1 ? "0" : "X_COORD";
+    const std::string y_coord = shape.h == 1 ? "0" : "Y_COORD";
+    const std::string s_coord = shape.c == 1 ? "0" : "S_COORD";
+    alpha_read = absl::StrCat("FLT4 alpha_val = args.alpha.Read(", x_coord,
+                              ", ", y_coord, ", ", s_coord, ");\n");
+    if (shape.c == 1) {
+      alpha_read += "  alpha_val.y = alpha_val.x;\n";
+      alpha_read += "  alpha_val.z = alpha_val.x;\n";
+      alpha_read += "  alpha_val.w = alpha_val.x;\n";
+    }
+  }
+
   if (attr.clip != 0) {
     if (definition.precision == CalculationsPrecision::F32) {
-      result->args_.AddFloat("clip", attr.clip);
+      result.args_.AddFloat("clip", attr.clip);
     } else {
-      result->args_.AddHalf("clip", half(attr.clip));
+      result.args_.AddHalf("clip", half(attr.clip));
     }
-    result->code_ =
+    result.code_ =
+        alpha_read +
         "in_out_value = clamp(in_out_value, (FLT4)(0.0f), (FLT4)(args.clip)) + "
-        "min((FLT4)(0.0f), in_out_value) * args.alpha.Read(S_COORD);";
+        "min((FLT4)(0.0f), in_out_value) * alpha_val;";
   } else {
-    result->code_ =
+    result.code_ =
+        alpha_read +
         "in_out_value = max((FLT4)(0.0f), in_out_value) + min((FLT4)(0.0f), "
-        "in_out_value) * args.alpha.Read(S_COORD);";
+        "in_out_value) * alpha_val;";
   }
 
-  auto alpha =
-      absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
-  if (!alpha) {
-    return absl::InvalidArgumentError("Alpha is missing");
-  }
-  TensorLinearDescriptor desc;
-  desc.storage_type =
-      DeduceLinearStorageType(definition.GetPrimaryStorageType());
-  desc.element_type = definition.GetPrimaryDataType();
-  desc.UploadLinearData(*alpha);
-
-  result->args_.AddObject(
-      "alpha", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
-
-  return absl::OkStatus();
+  return result;
 }
 
 }  // namespace cl
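PReLU computes max(0, x) + alpha * min(0, x), with an optional clip that clamps the positive part; the rewrite above inlines the alpha read for both the per-channel (Linear) and full HWC variants. A scalar reference, assuming only <algorithm>:

#include <algorithm>

// Scalar reference for the elementwise code assembled above.
float PReLURef(float x, float alpha, float clip) {
  const float pos = (clip != 0.0f) ? std::min(std::max(x, 0.0f), clip)
                                   : std::max(x, 0.0f);
  return pos + alpha * std::min(x, 0.0f);
}

// Matches the HWC test below: x = -2.0, alpha = 0.7 -> 0 + 0.7 * -2.0 = -1.4.
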
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h
index b673217..5d2a41b 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu.h
@@ -31,9 +31,9 @@
 namespace gpu {
 namespace cl {
 
-absl::Status CreatePReLU(const CreationContext& creation_context,
+GPUOperation CreatePReLU(const DeviceInfo& device_info,
                          const OperationDef& definition,
-                         const PReLUAttributes& attr, GPUOperation* result);
+                         const PReLUAttributes& attr);
 
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc
index 06ff09c..ef4b8c1 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc
@@ -52,8 +52,8 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      GPUOperation operation;
-      ASSERT_OK(CreatePReLU(creation_context_, op_def, attr, &operation));
+      GPUOperation operation =
+          CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 2, 1, 2), &dst_tensor));
       EXPECT_THAT(dst_tensor.data,
@@ -83,8 +83,8 @@
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      GPUOperation operation;
-      ASSERT_OK(CreatePReLU(creation_context_, op_def, attr, &operation));
+      GPUOperation operation =
+          CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
       ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
                                     BHWC(1, 2, 1, 2), &dst_tensor));
       EXPECT_THAT(dst_tensor.data,
@@ -93,6 +93,37 @@
   }
 }
 
+TEST_F(OpenCLOperationTest, PReLUHWCAlpha) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 1, 2);
+  src_tensor.data = {0.0f, -1.0f, -2.0f, 3.0f};
+
+  PReLUAttributes attr;
+  ::tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
+  hwc_tensor.shape = HWC(2, 1, 2);
+  hwc_tensor.data = {0.5f, -2.0f, 0.7f, 4.7f};
+  attr.alpha = hwc_tensor;
+  attr.clip = 0.0;
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      GPUOperation operation =
+          CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
+      ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
+                                    BHWC(1, 2, 1, 2), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data,
+                  Pointwise(FloatNear(eps), {0.0f, 2.0f, -1.4f, 3.0f}));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc
new file mode 100644
index 0000000..4f889d4
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.cc
@@ -0,0 +1,102 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h"
+
+#include <string>
+
+#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
+#include "tensorflow/lite/delegates/gpu/cl/precision.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+namespace {
+std::string GetReduceChannelsKernelCode(const OperationDef& op_def,
+                                        const OperationType& op_type) {
+  std::string c = GetCommonDefines(op_def.precision);
+  if (op_type == OperationType::ADD) {
+    c += "#define OP(a, b) ((a) + (b))\n";
+  } else if (op_type == OperationType::MUL) {
+    c += "#define OP(a, b) ((a) * (b))\n";
+  } else if (op_type == OperationType::MAXIMUM) {
+    c += "#define OP(a, b) max(a, b)\n";
+  } else if (op_type == OperationType::MINIMUM) {
+    c += "#define OP(a, b) min(a, b)\n";
+  }
+  c += "__kernel void main_function($0) {\n";
+  c += "  int X = get_global_id(0);\n";
+  c += "  int Y = get_global_id(1);\n";
+  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
+       "return;\n";
+  if (op_type == OperationType::ADD) {
+    c += "  FLT4 reduced = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
+  } else if (op_type == OperationType::MUL) {
+    c += "  FLT4 reduced = (FLT4)(1.0f, 1.0f, 1.0f, 1.0f);\n";
+  } else {
+    c += "  FLT4 V0 = args.src_tensor.Read(X, Y, 0);\n";
+    c += "  FLT4 reduced = (FLT4)(V0.x, V0.x, V0.x, V0.x);\n";
+  }
+  c += "  int s = 0;\n";
+  c += "  for (; s < args.src_tensor.Slices() - 1; ++s) {\n";
+  c += "    FLT4 V = args.src_tensor.Read(X, Y, s);\n";
+  c += "    reduced = OP(reduced, V);\n";
+  c += "  }\n";
+  c += "  FLT reduced_final = OP(OP(reduced.x, reduced.y), OP(reduced.z, "
+       "reduced.w));\n";
+  c += "  FLT last_reduce;\n";
+  c += "  FLT4 last_val = args.src_tensor.Read(X, Y, s);\n";
+  c += "  int ch_rem = args.src_tensor.Channels() % 4;\n";
+  c += "  if (ch_rem == 0) {\n";
+  c += "    last_reduce = OP(OP(last_val.x, last_val.y), OP(last_val.z, "
+       "last_val.w));\n";
+  c += "  } else if (ch_rem == 1) {\n";
+  c += "    last_reduce = OP(OP(last_val.x, last_val.y), last_val.z);\n";
+  c += "  } else if (ch_rem == 2) {\n";
+  c += "    last_reduce = OP(last_val.x, last_val.y);\n";
+  c += "  } else {\n";
+  c += "    last_reduce = last_val.x;\n";
+  c += "  }\n";
+  c += "  reduced_final = OP(reduced_final, last_reduce);\n";
+  c += "  FLT4 result = (FLT4)(reduced_final, 0.0f, 0.0f, 0.0f);\n";
+  c += "  args.dst_tensor.Write(result, X, Y, 0);\n";
+  c += "}\n";
+  return c;
+}
+}  // namespace
+
+GPUOperation CreateReduce(const OperationDef& definition,
+                          const OperationType& op_type) {
+  GPUOperation op(definition);
+  auto src_desc = definition.src_tensors[0];
+  if (definition.IsBatchSupported()) {
+    src_desc.SetStateVar("BatchedWidth", "true");
+  }
+  op.AddSrcTensor("src_tensor", src_desc);
+  auto dst_desc = definition.dst_tensors[0];
+  if (definition.IsBatchSupported()) {
+    dst_desc.SetStateVar("BatchedWidth", "true");
+  }
+  op.AddDstTensor("dst_tensor", dst_desc);
+  op.code_ = GetReduceChannelsKernelCode(definition, op_type);
+  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_ZIs1;
+  return op;
+}
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
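The kernel reduces whole FLT4 slices in the main loop and treats the final slice separately because the channel count need not be a multiple of four; only the valid lanes of that slice may enter the reduction, since for MUL, MAXIMUM, and MINIMUM the padding lanes would corrupt the result (for ADD they happen to be harmless zeros). A host-side trace of the ADD case, matching the ReduceSumChannels test below:

// For C = 5 channels: Slices = DivideRoundUp(5, 4) = 2 and ch_rem = 1.
// The loop reduces slice 0 in full ({1.1, 2.1, 0.7, 0.3} -> 4.2), then
// only last_val.x (= 1.2) of slice 1 is folded in: 4.2 + 1.2 = 5.4,
// the expected value for row 0 of the test.
float ReduceChannelsAddRef(const float* channels, int count) {
  float acc = 0.0f;
  for (int i = 0; i < count; ++i) acc += channels[i];
  return acc;
}
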
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h
new file mode 100644
index 0000000..ec5329a
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_
+
+#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+
+GPUOperation CreateReduce(const OperationDef& definition,
+                          const OperationType& op_type);
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_REDUCE_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc
new file mode 100644
index 0000000..9275c45
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/reduce_test.cc
@@ -0,0 +1,129 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h"
+
+#include <cmath>
+#include <cstdlib>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+using ::testing::FloatNear;
+using ::testing::Pointwise;
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+namespace {
+
+TEST_F(OpenCLOperationTest, ReduceSumChannels) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 1, 5);
+  src_tensor.data = {1.1, 2.1, 0.7, 0.3, 1.2, 3.1, 4.1, 0.0, 1.0, 4.4};
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      GPUOperation operation = CreateReduce(op_def, OperationType::ADD);
+      ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
+                                    BHWC(1, 2, 1, 1), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {5.4f, 12.6f}));
+    }
+  }
+}
+
+TEST_F(OpenCLOperationTest, ReduceProductChannels) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 1, 2);
+  src_tensor.data = {1.1, 2.0, 3.1, 4.0};
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      GPUOperation operation = CreateReduce(op_def, OperationType::MUL);
+      ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
+                                    BHWC(1, 2, 1, 1), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {2.2f, 12.4f}));
+    }
+  }
+}
+
+TEST_F(OpenCLOperationTest, ReduceMaxChannels) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 1, 6);
+  src_tensor.data = {1.1,  2.0,  -0.3, -100.0, 32.6, 1.1,
+                     -3.1, -4.0, -5.0, -7.0,   -2.0, -100.0};
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      GPUOperation operation = CreateReduce(op_def, OperationType::MAXIMUM);
+      ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
+                                    BHWC(1, 2, 1, 1), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {32.6f, -2.0f}));
+    }
+  }
+}
+
+TEST_F(OpenCLOperationTest, ReduceMinChannels) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 1, 6);
+  src_tensor.data = {1.1,  2.0,  -0.3, -100.0, 32.6, 1.1,
+                     -3.1, -4.0, -5.0, -7.0,   -2.0, 100.0};
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      GPUOperation operation = CreateReduce(op_def, OperationType::MINIMUM);
+      ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
+                                    BHWC(1, 2, 1, 1), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data, Pointwise(FloatNear(eps), {-100.0f, -7.0f}));
+    }
+  }
+}
+
+}  // namespace
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
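
The expected vectors in these tests are plain per-row folds over the channel axis. A small self-contained CPU reference (not part of the test suite) that reproduces the ReduceSumChannels expectations:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Folds every group of `channels` consecutive values with `fold`, yielding
// one output per spatial position -- the same reduction the GPU kernel does.
std::vector<float> ReferenceReduceChannels(
    const std::vector<float>& data, int channels,
    const std::function<float(float, float)>& fold) {
  std::vector<float> out;
  for (std::size_t i = 0; i < data.size(); i += channels) {
    float acc = data[i];
    for (int c = 1; c < channels; ++c) acc = fold(acc, data[i + c]);
    out.push_back(acc);
  }
  return out;
}

int main() {
  const std::vector<float> src = {1.1f, 2.1f, 0.7f, 0.3f, 1.2f,
                                  3.1f, 4.1f, 0.0f, 1.0f, 4.4f};
  const auto sums = ReferenceReduceChannels(
      src, 5, [](float a, float b) { return a + b; });
  assert(std::fabs(sums[0] - 5.4f) < 1e-5f);   // first row of ReduceSumChannels
  assert(std::fabs(sums[1] - 12.6f) < 1e-5f);  // second row
  return 0;
}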
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc
index b7cfa5f..f0e0c41 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc
@@ -84,6 +84,17 @@
                           batch_size, stride_x, padding_x);
 }
 
+std::string GetXStrideCorrectedV2(const std::string& src_x,
+                                  const std::string& batch_size,
+                                  const std::string& stride_x,
+                                  const std::string& padding_x) {
+  // int p0 = src_x / batch_size;
+  // int b0 = src_x % batch_size;
+  // return (p0 * stride_x + padding_x) * batch_size + b0;
+  return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x,
+                          batch_size, stride_x, padding_x);
+}
+
 float4 GetMaskForLastPlane(int channels) {
   float4 mask = float4(0.0f);
   const int reminder = channels % 4 == 0 ? 4 : channels % 4;
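
The single Substitute expression in GetXStrideCorrectedV2 is the three commented steps fused into one. A standalone check (illustration only, not part of the library) that the fused integer formula matches the step-by-step form over a small parameter range:

#include <cassert>

int main() {
  for (int src_x = 0; src_x < 64; ++src_x) {
    for (int batch_size = 1; batch_size <= 4; ++batch_size) {
      for (int stride_x = 1; stride_x <= 3; ++stride_x) {
        for (int padding_x = -2; padding_x <= 2; ++padding_x) {
          // Step-by-step form from the comment above.
          const int p0 = src_x / batch_size;
          const int b0 = src_x % batch_size;
          const int expected = (p0 * stride_x + padding_x) * batch_size + b0;
          // Fused form, as written by absl::Substitute.
          const int fused =
              ((src_x / batch_size) * stride_x + padding_x) * batch_size +
              src_x % batch_size;
          assert(expected == fused);
        }
      }
    }
  }
  return 0;
}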
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.h b/tensorflow/lite/delegates/gpu/cl/kernels/util.h
index b1dd4fe..aa9f599 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/util.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/util.h
@@ -44,6 +44,14 @@
                                 const std::string& stride_x,
                                 const std::string& padding_x);
 
+// Calculates the correct X coordinate when stride != 1 and batch != 1, for
+// layouts with B after W (for example, HWBC4) where W and B are stored in a
+// single axis of the GPU resource.
+std::string GetXStrideCorrectedV2(const std::string& src_x,
+                                  const std::string& batch_size,
+                                  const std::string& stride_x,
+                                  const std::string& padding_x);
+
 template <DataType S, typename T>
 void RearrangeWeightsToOHWIOGroupI4O4(
     const tflite::gpu::Tensor<OHWI, S>& weights, int out_group_size,
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
index eab957e..7fa7978 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc
@@ -35,8 +35,8 @@
     const DeviceInfo& device_info, const OperationDef& op_def,
     ModelHints hints) {
   if (IsConvConstantsSupported(device_info, op_def, attr)) {
-    ConvConstants conv = CreateConvConstants(device_info, op_def, attr);
-    return absl::make_unique<ConvConstants>(std::move(conv));
+    GPUOperation conv = CreateConvConstants(device_info, op_def, attr);
+    return absl::make_unique<GPUOperation>(std::move(conv));
   } else {
     ConvTexture conv = CreateConvTexture(device_info, op_def, attr);
     return absl::make_unique<ConvTexture>(std::move(conv));
@@ -66,8 +66,8 @@
     const Convolution2DAttributes& attr, const BHWC& dst_shape,
     const DeviceInfo& device_info, const OperationDef& op_def) {
   if (IsConvConstantsSupported(device_info, op_def, attr)) {
-    ConvConstants conv = CreateConvConstants(device_info, op_def, attr);
-    return absl::make_unique<ConvConstants>(std::move(conv));
+    GPUOperation conv = CreateConvConstants(device_info, op_def, attr);
+    return absl::make_unique<GPUOperation>(std::move(conv));
   } else {
     ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape);
     return absl::make_unique<ConvPowerVR>(std::move(conv));
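
With CreateConvConstants now returning the base GPUOperation by value, the selector can type-erase it directly; had the factory still returned a derived kernel class, assigning it to a base-typed local would slice off the derived state. A toy illustration of the pattern, with placeholder types rather than the real TFLite GPU classes:

#include <memory>
#include <utility>

struct GPUOperationLike {
  int some_state = 0;  // stands in for generated code, args, grid size, etc.
};

// Factory returns the base type by value, so there is no derived object to
// slice and the caller wraps it without naming a concrete kernel class.
GPUOperationLike CreateConvConstantsLike() { return GPUOperationLike{42}; }

std::unique_ptr<GPUOperationLike> SelectConvLike() {
  GPUOperationLike conv = CreateConvConstantsLike();
  return std::make_unique<GPUOperationLike>(std::move(conv));
}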
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc
index 2d61def..b04335a 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc
@@ -33,8 +33,8 @@
     return absl::make_unique<DepthwiseConv3x3>(
         CreateDepthwiseConv3x3(device_info, op_def, attr));
   } else {
-    return absl::make_unique<DepthwiseConvolution>(
-        CreateDepthwiseConvolution(device_info, op_def, attr));
+    return absl::make_unique<GPUOperation>(
+        CreateDepthwiseConvolution2D(device_info, op_def, attr));
   }
 }
 
@@ -45,8 +45,8 @@
     return absl::make_unique<DepthwiseConv3x3>(
         CreateDepthwiseConv3x3(device_info, op_def, attr));
   } else {
-    return absl::make_unique<DepthwiseConvolution>(
-        CreateDepthwiseConvolution(device_info, op_def, attr));
+    return absl::make_unique<GPUOperation>(
+        CreateDepthwiseConvolution2D(device_info, op_def, attr));
   }
 }
 
@@ -62,8 +62,8 @@
     return absl::make_unique<DepthwiseConv3x3>(
         CreateDepthwiseConv3x3(device_info, op_def, attr));
   } else {
-    return absl::make_unique<DepthwiseConvolution>(
-        CreateDepthwiseConvolution(device_info, op_def, attr));
+    return absl::make_unique<GPUOperation>(
+        CreateDepthwiseConvolution2D(device_info, op_def, attr));
   }
 }
 }  // namespace
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
index dc18cde..5e8e4a9 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
@@ -127,7 +127,7 @@
 
 }  // namespace
 
-absl::Status GPUOperationFromNode(const CreationContext& creation_context,
+absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
                                   const OperationDef& op_def, ModelHints hints,
                                   const std::vector<Value*>& inputs,
                                   const std::vector<Value*>& outputs,
@@ -156,8 +156,8 @@
       } else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
         auto attr =
             absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
-        GPUOperation operation = CreateElementwise(
-            creation_context.GetDeviceInfo(), op_def, op_type, attr);
+        GPUOperation operation =
+            CreateElementwise(device_info, op_def, op_type, attr);
         *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
         return absl::OkStatus();
       }
@@ -170,8 +170,7 @@
       for (int i = 0; i < inputs.size(); ++i) {
         channels[i] = inputs[i]->tensor.shape.c;
       }
-      return SelectConcat(attr, channels, op_def,
-                          creation_context.device->info_, gpu_op);
+      return SelectConcat(attr, channels, op_def, device_info, gpu_op);
     }
     case OperationType::CONVOLUTION_2D: {
       auto attr =
@@ -179,16 +178,14 @@
       auto input_shape = inputs[0]->tensor.shape;
       auto output_shape = outputs[0]->tensor.shape;
       if (inputs.size() == 1) {
-        if (WinogradFromNode(creation_context.GetDeviceInfo(), inputs, outputs,
-                             op_def, hints, input_shape, output_shape, attr,
-                             gpu_subgraph)
+        if (WinogradFromNode(device_info, inputs, outputs, op_def, hints,
+                             input_shape, output_shape, attr, gpu_subgraph)
                 .ok()) {
           return absl::OkStatus();
         } else {
           gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
-          *gpu_op = SelectConvolution(attr, output_shape,
-                                      creation_context.GetDeviceInfo(), op_def,
-                                      hints);
+          *gpu_op =
+              SelectConvolution(attr, output_shape, device_info, op_def, hints);
           return absl::OkStatus();
         }
       } else {
@@ -206,8 +203,8 @@
         conv_def.src_tensors[1] = weights_desc;
         ConvWeightsDescription conv_weights_desc;
         conv_op.operation = SelectConvolutionWithDynamicWeights(
-            attr, weights_shape, output_shape, creation_context.GetDeviceInfo(),
-            conv_def, hints, &conv_weights_desc);
+            attr, weights_shape, output_shape, device_info, conv_def, hints,
+            &conv_weights_desc);
 
         int aligned_output =
             AlignByN(weights_shape.b, conv_weights_desc.output_group_size * 4);
@@ -232,41 +229,39 @@
     case OperationType::CONVOLUTION_TRANSPOSED: {
       auto attr = absl::any_cast<ConvolutionTransposedAttributes>(
           node.operation.attributes);
-      *gpu_op = SelectConvolutionTransposed(
-          attr, creation_context.GetDeviceInfo(), op_def);
+      *gpu_op = SelectConvolutionTransposed(attr, device_info, op_def);
       return absl::OkStatus();
     }
     case OperationType::DEPTHWISE_CONVOLUTION: {
       auto attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
           node.operation.attributes);
-      *gpu_op =
-          SelectDWConvolution(attr, creation_context.GetDeviceInfo(), op_def);
+      *gpu_op = SelectDWConvolution(attr, device_info, op_def);
       return absl::OkStatus();
     }
     case OperationType::FULLY_CONNECTED: {
       auto attr =
           absl::any_cast<FullyConnectedAttributes>(node.operation.attributes);
-      *gpu_op = SelectFullyConnected(attr, creation_context.GetDeviceInfo(),
-                                     op_def, inputs[0]->tensor.shape.b);
+      *gpu_op = SelectFullyConnected(attr, device_info, op_def,
+                                     inputs[0]->tensor.shape.b);
       return absl::OkStatus();
     }
     case OperationType::LSTM: {
-      SelectLSTM(op_def, creation_context.device->info_, gpu_op);
+      *gpu_op = SelectLSTM(op_def, device_info);
       return absl::OkStatus();
     }
     case OperationType::MAX_UNPOOLING_2D: {
       auto attr =
           absl::any_cast<MaxUnpooling2DAttributes>(node.operation.attributes);
-      SelectMaxUnpooling(attr, op_def, gpu_op);
+      *gpu_op = SelectMaxUnpooling(attr, op_def);
       return absl::OkStatus();
     }
     case OperationType::MEAN: {
       auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
-      return SelectMean(attr, op_def, creation_context.device->info_, gpu_op);
+      return SelectMean(attr, op_def, device_info, gpu_op);
     }
     case OperationType::MEAN_STDDEV_NORMALIZATION: {
       MeanStdDevNormalization operation =
-          CreateMeanStdDevNormalization(op_def, creation_context.device->info_);
+          CreateMeanStdDevNormalization(op_def, device_info);
       *gpu_op =
           absl::make_unique<MeanStdDevNormalization>(std::move(operation));
       return absl::OkStatus();
@@ -279,12 +274,13 @@
     case OperationType::POOLING_2D: {
       auto attr =
           absl::any_cast<Pooling2DAttributes>(node.operation.attributes);
-      SelectPooling(attr, op_def, gpu_op);
+      *gpu_op = SelectPooling(attr, op_def);
       return absl::OkStatus();
     }
     case OperationType::PRELU: {
       auto attr = absl::any_cast<PReLUAttributes>(node.operation.attributes);
-      return SelectPReLU(attr, creation_context, op_def, gpu_op);
+      *gpu_op = SelectPReLU(attr, device_info, op_def);
+      return absl::OkStatus();
     }
     case OperationType::QUANTIZE_AND_DEQUANTIZE: {
       auto attr = absl::any_cast<QuantizeAndDequantizeAttributes>(
@@ -360,8 +356,8 @@
       } else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
         auto attr =
             absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
-        GPUOperation operation = CreateElementwise(
-            creation_context.GetDeviceInfo(), op_def, op_type, attr);
+        GPUOperation operation =
+            CreateElementwise(device_info, op_def, op_type, attr);
         *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
         return absl::OkStatus();
       }
@@ -369,8 +365,8 @@
           "No support of ", node.operation.type, " with this parameters"));
     }
     default:
-      return SelectDefault(creation_context.device->info_, op_def, hints,
-                           inputs, outputs, node, gpu_subgraph);
+      return SelectDefault(device_info, op_def, hints, inputs, outputs, node,
+                           gpu_subgraph);
   }
 }
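
The recurring edit in this file narrows a dependency: operation selection reads only device capabilities, not a live CL device or context, so GPUOperationFromNode now takes DeviceInfo directly instead of digging it out of CreationContext at every call site. A sketch of the call-site change under placeholder types:

struct DeviceInfo {};
struct Device { DeviceInfo info_; };
struct CreationContext {
  Device* device = nullptr;
  const DeviceInfo& GetDeviceInfo() const { return device->info_; }
};

// Before: every selector call reached through the context, e.g.
//   SelectConvolution(attr, shape, creation_context.GetDeviceInfo(), ...);
// After: the caller extracts DeviceInfo once and passes it down.
void BuildGraph(const CreationContext& creation_context) {
  const DeviceInfo& device_info = creation_context.GetDeviceInfo();
  (void)device_info;  // forwarded as GPUOperationFromNode(device_info, ...)
}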
 
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
index f237a38..640432e 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h
@@ -29,7 +29,7 @@
 namespace gpu {
 namespace cl {
 
-absl::Status GPUOperationFromNode(const CreationContext& creation_context,
+absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
                                   const OperationDef& op_def, ModelHints hints,
                                   const std::vector<Value*>& inputs,
                                   const std::vector<Value*>& outputs,
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
index 5f2f8f0..4dbb1ff 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc
@@ -45,10 +45,9 @@
 namespace gpu {
 namespace cl {
 
-void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
-                std::unique_ptr<GPUOperation>* ptr) {
-  LSTM operation = CreateLSTM(op_def, device_info);
-  *ptr = absl::make_unique<LSTM>(std::move(operation));
+std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
+                                         const DeviceInfo& device_info) {
+  return absl::make_unique<GPUOperation>(CreateLSTM(op_def, device_info));
 }
 
 std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
@@ -56,27 +55,21 @@
   return absl::make_unique<GPUOperation>(CreateReLU(op_def, attr));
 }
 
-absl::Status SelectPReLU(const PReLUAttributes& attr,
-                         const CreationContext& creation_context,
-                         const OperationDef& op_def,
-                         std::unique_ptr<GPUOperation>* ptr) {
-  GPUOperation operation;
-  RETURN_IF_ERROR(CreatePReLU(creation_context, op_def, attr, &operation));
-  *ptr = absl::make_unique<GPUOperation>(std::move(operation));
-  return absl::OkStatus();
+std::unique_ptr<GPUOperation> SelectPReLU(const PReLUAttributes& attr,
+                                          const DeviceInfo& device_info,
+                                          const OperationDef& op_def) {
+  return absl::make_unique<GPUOperation>(
+      CreatePReLU(device_info, op_def, attr));
 }
 
-void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def,
-                   std::unique_ptr<GPUOperation>* ptr) {
-  Pooling pooling = CreatePooling(op_def, attr);
-  *ptr = absl::make_unique<Pooling>(std::move(pooling));
+std::unique_ptr<GPUOperation> SelectPooling(const Pooling2DAttributes& attr,
+                                            const OperationDef& op_def) {
+  return absl::make_unique<GPUOperation>(CreatePooling(op_def, attr));
 }
 
-void SelectMaxUnpooling(const MaxUnpooling2DAttributes& attr,
-                        const OperationDef& op_def,
-                        std::unique_ptr<GPUOperation>* ptr) {
-  MaxUnpooling operation = CreateMaxUnpooling(op_def, attr);
-  *ptr = absl::make_unique<MaxUnpooling>(std::move(operation));
+std::unique_ptr<GPUOperation> SelectMaxUnpooling(
+    const MaxUnpooling2DAttributes& attr, const OperationDef& op_def) {
+  return absl::make_unique<GPUOperation>(CreateMaxUnpooling(op_def, attr));
 }
 
 void SelectAdd(const OperationDef& op_def, const std::vector<int>& channels,
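
Each selector in this file follows the same refactor: an out-parameter the caller had to pre-declare becomes a plain return value. A minimal before/after sketch with placeholder types:

#include <memory>

struct Op {};

// Old shape: caller supplies a slot, callee fills it.
void SelectOpOld(std::unique_ptr<Op>* ptr) { *ptr = std::make_unique<Op>(); }

// New shape: the result is the return value, so it can never be left unset
// and call sites collapse to a single assignment.
std::unique_ptr<Op> SelectOpNew() { return std::make_unique<Op>(); }

void Caller() {
  std::unique_ptr<Op> a;
  SelectOpOld(&a);         // before
  auto b = SelectOpNew();  // after
  (void)b;
}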
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
index 71d4c1f..c6c604d 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h
@@ -28,23 +28,21 @@
 namespace gpu {
 namespace cl {
 
-void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
-                std::unique_ptr<GPUOperation>* ptr);
+std::unique_ptr<GPUOperation> SelectLSTM(const OperationDef& op_def,
+                                         const DeviceInfo& device_info);
 
 std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
                                          const OperationDef& op_def);
 
-absl::Status SelectPReLU(const PReLUAttributes& attr,
-                         const CreationContext& creation_context,
-                         const OperationDef& op_def,
-                         std::unique_ptr<GPUOperation>* ptr);
+std::unique_ptr<GPUOperation> SelectPReLU(const PReLUAttributes& attr,
+                                          const DeviceInfo& device_info,
+                                          const OperationDef& op_def);
 
-void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def,
-                   std::unique_ptr<GPUOperation>* ptr);
+std::unique_ptr<GPUOperation> SelectPooling(const Pooling2DAttributes& attr,
+                                            const OperationDef& op_def);
 
-void SelectMaxUnpooling(const MaxUnpooling2DAttributes& attr,
-                        const OperationDef& op_def,
-                        std::unique_ptr<GPUOperation>* ptr);
+std::unique_ptr<GPUOperation> SelectMaxUnpooling(
+    const MaxUnpooling2DAttributes& attr, const OperationDef& op_def);
 
 void SelectAdd(const OperationDef& op_def, const std::vector<int>& channels,
                int dst_channels, std::unique_ptr<GPUOperation>* ptr);
diff --git a/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh b/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh
index 0fd2d33..56d1e10 100755
--- a/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh
+++ b/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh
@@ -83,11 +83,17 @@
 declare -a BUILD_CONFIG
 abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r')
 if [[ "$abi_version" == "armeabi-v7a" ]]; then
-#"32 bit"
+#"32 bit ARM"
 BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie )
-else
-#"64 bit"
+elif [[ "$abi_version" == "arm64-v8a" ]]; then
+#"64 bit ARM"
 BUILD_CONFIG=( --config=android_arm64 -c opt )
+elif [[ "$abi_version" == "x86_64" ]]; then
+#"x86_64"
+BUILD_CONFIG=( --config=android_x86_64 -c opt )
+else
+echo "Error: Unknown processor ABI"
+exit 1
 fi
 
 bazel build "${BUILD_CONFIG[@]}" //$SHELL_DIR:$BINARY_NAME
diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD
index 60a0fda..3b14d14 100644
--- a/tensorflow/lite/delegates/gpu/common/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/BUILD
@@ -10,6 +10,7 @@
     srcs = ["convert.cc"],
     hdrs = ["convert.h"],
     deps = [
+        ":data_type",
         ":shape",
         ":status",
         ":tensor",
@@ -73,6 +74,7 @@
         ":types",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
     ],
 )
 
@@ -81,7 +83,6 @@
     srcs = ["model.cc"],
     hdrs = ["model.h"],
     deps = [
-        ":data_type",
         ":shape",
         ":status",
         ":tensor",
@@ -97,6 +98,7 @@
     srcs = ["model_test.cc"],
     deps = [
         ":model",
+        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -115,10 +117,10 @@
         ":shape",
         ":status",
         ":tensor",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
         "//tensorflow/lite/delegates:utils",
-        "//tensorflow/lite:context",
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
@@ -133,10 +135,14 @@
     name = "model_builder_test",
     srcs = ["model_builder_test.cc"],
     deps = [
+        ":data_type",
         ":model_builder",
+        ":shape",
+        ":tensor",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite/c:common",
+        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -152,10 +158,8 @@
         ":shape",
         ":status",
         ":tensor",
-        "//tensorflow/lite:context",
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite/c:common",
-        "//tensorflow/lite/delegates:utils",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels/internal:reference_base",
         "//tensorflow/lite/kernels/internal:tensor",
@@ -186,10 +190,12 @@
         ":model",
         ":model_builder_helper",
         ":status",
+        ":tensor",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/delegates:utils",
         "//tensorflow/lite/kernels:kernel_util",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -199,9 +205,9 @@
     hdrs = ["operations.h"],
     deps = [
         ":data_type",
-        ":model",
         ":shape",
         ":status",
+        ":tensor",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/types:variant",
     ],
@@ -213,11 +219,12 @@
     hdrs = ["quantization_util.h"],
     deps = [
         ":status",
-        "//tensorflow/lite:kernel_api",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels/internal:optimized_base",
+        "//tensorflow/lite/kernels/internal:tensor",
         "//tensorflow/lite/kernels/internal:types",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
     ],
 )
 
@@ -227,6 +234,9 @@
     deps = [
         ":quantization_util",
         "//tensorflow/lite:util",
+        "//tensorflow/lite/c:common",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -237,10 +247,7 @@
     name = "shape",
     srcs = ["shape.cc"],
     hdrs = ["shape.h"],
-    deps = [
-        "@com_google_absl//absl/hash",
-        "@com_google_absl//absl/strings",
-    ],
+    deps = ["@com_google_absl//absl/strings"],
 )
 
 cc_test(
@@ -288,6 +295,9 @@
     srcs = ["memory_management_test.cc"],
     deps = [
         ":memory_management",
+        ":shape",
+        ":types",
+        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -334,9 +344,5 @@
     name = "workgroup_selection",
     srcs = ["workgroup_selection.cc"],
     hdrs = ["workgroup_selection.h"],
-    deps = [
-        ":status",
-        ":types",
-        ":util",
-    ],
+    deps = [":util"],
 )
diff --git a/tensorflow/lite/delegates/gpu/common/convert.cc b/tensorflow/lite/delegates/gpu/common/convert.cc
index fb0caf9..3920692 100644
--- a/tensorflow/lite/delegates/gpu/common/convert.cc
+++ b/tensorflow/lite/delegates/gpu/common/convert.cc
@@ -15,9 +15,19 @@
 
 #include "tensorflow/lite/delegates/gpu/common/convert.h"
 
+#include <stdint.h>
+#include <string.h>
+
+#include <string>
+#include <vector>
+
 #include <fp16.h>
 #include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 
diff --git a/tensorflow/lite/delegates/gpu/common/convert.h b/tensorflow/lite/delegates/gpu/common/convert.h
index 3aba9c9..c7a6c17 100644
--- a/tensorflow/lite/delegates/gpu/common/convert.h
+++ b/tensorflow/lite/delegates/gpu/common/convert.h
@@ -16,9 +16,12 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CONVERT_H_
 
+#include <stdint.h>
+
 #include <vector>
 
 #include "absl/types/span.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
diff --git a/tensorflow/lite/delegates/gpu/common/custom_parsers.h b/tensorflow/lite/delegates/gpu/common/custom_parsers.h
index d70e584..2644864 100644
--- a/tensorflow/lite/delegates/gpu/common/custom_parsers.h
+++ b/tensorflow/lite/delegates/gpu/common/custom_parsers.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_CUSTOM_PARSERS_H_
 
-#include <string>
+#include <stdint.h>
 
 #include "absl/strings/string_view.h"
 #include "absl/types/any.h"
diff --git a/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc b/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc
index 5aa1303..a4981a9 100644
--- a/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc
+++ b/tensorflow/lite/delegates/gpu/common/default/custom_parsers.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
 
+#include <stdint.h>
+
 #include <string>
 
 #include "absl/strings/str_cat.h"
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
index 14fb48a..b56745d 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
@@ -15,8 +15,6 @@
 
 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
 
-#include <algorithm>
-#include <cctype>
 #include <string>
 
 #include "absl/strings/ascii.h"
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.cc b/tensorflow/lite/delegates/gpu/common/memory_management.cc
index d7e6a06..2a637d5 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management.cc
@@ -15,17 +15,21 @@
 
 #include "tensorflow/lite/delegates/gpu/common/memory_management.h"
 
-#include <algorithm>
-#include <limits>
+#include <cstddef>
 #include <numeric>
-#include <queue>
-#include <set>
-#include <type_traits>
+#include <utility>
 #include <vector>
 
+#include "tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management.h b/tensorflow/lite/delegates/gpu/common/memory_management.h
index 7df4947..9f1adce 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management.h
@@ -16,16 +16,12 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_H_
 
-#include <cstdint>
-#include <memory>
+#include <stddef.h>
+
 #include <vector>
 
 #include "absl/memory/memory.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h"
-#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h"
-#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h"
-#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h"
-#include "tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h
index fdccce5..018e5a9 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/equality_assignment.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_EQUALITY_ASSIGNMENT_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_EQUALITY_ASSIGNMENT_H_
 
+#include <stddef.h>
+
 #include <queue>
 #include <vector>
 
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc
index 2c138b4..b07ab61 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.cc
@@ -16,12 +16,13 @@
 #include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h"
 
 #include <algorithm>
-#include <cstdint>
-#include <cstdlib>
+#include <cstddef>
 #include <set>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h
index 4703522..e207ab3 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h
@@ -16,7 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_BY_BREADTH_ASSIGNMENT_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_BY_BREADTH_ASSIGNMENT_H_
 
-#include <cstdint>
+#include <stddef.h>
+
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
index 76309ce..130f271 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc
@@ -16,8 +16,13 @@
 #include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h"
 
 #include <algorithm>
+#include <cstddef>
+#include <iterator>
+#include <vector>
 
+#include "absl/status/status.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h
index b0ad9d1..198a25c 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_BY_SIZE_ASSIGNMENT_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_BY_SIZE_ASSIGNMENT_H_
 
+#include <stddef.h>
+
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h
index 8c3719e..048ed38 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_in_order_assignment.h
@@ -16,7 +16,10 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_IN_ORDER_ASSIGNMENT_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_GREEDY_IN_ORDER_ASSIGNMENT_H_
 
+#include <stddef.h>
+
 #include <algorithm>
+#include <iterator>
 #include <list>
 #include <queue>
 #include <set>
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/internal.cc b/tensorflow/lite/delegates/gpu/common/memory_management/internal.cc
index bbcd373..27126aa 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/internal.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/internal.cc
@@ -16,6 +16,11 @@
 #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h"
 
 #include <algorithm>
+#include <cstddef>
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/internal.h b/tensorflow/lite/delegates/gpu/common/memory_management/internal.h
index 702fd29..4d48f75 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/internal.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/internal.h
@@ -16,9 +16,9 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_INTERNAL_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_INTERNAL_H_
 
-#include <cstdint>
+#include <stddef.h>
+
 #include <limits>
-#include <memory>
 #include <vector>
 
 #include "absl/memory/memory.h"
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/internal_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management/internal_test.cc
index 757cb89..ed83e3c 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/internal_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/internal_test.cc
@@ -15,8 +15,12 @@
 
 #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h"
 
+#include <cstddef>
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc
index 059c23f..c56ac2e 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.cc
@@ -16,11 +16,15 @@
 #include "tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h"
 
 #include <algorithm>
+#include <cstddef>
+#include <limits>
 #include <queue>
-#include <set>
+#include <utility>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h
index 1284c12..df734ad 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/min_cost_flow_assignment.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_MIN_COST_FLOW_ASSIGNMENT_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_MIN_COST_FLOW_ASSIGNMENT_H_
 
+#include <stddef.h>
+
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h b/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h
index 8a00c67..d700f62 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/naive_assignment.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_NAIVE_ASSIGNMENT_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_NAIVE_ASSIGNMENT_H_
 
+#include <stddef.h>
+
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h"
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/types.cc b/tensorflow/lite/delegates/gpu/common/memory_management/types.cc
index 5cec0ca..101ca53 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/types.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/types.cc
@@ -16,7 +16,7 @@
 #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
 
 #include <algorithm>
-#include <cstdint>
+#include <cstddef>
 #include <queue>
 #include <vector>
 
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/types.h b/tensorflow/lite/delegates/gpu/common/memory_management/types.h
index a511152..f3257fc 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/types.h
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/types.h
@@ -16,8 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_TYPES_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_TYPES_H_
 
-#include <cstdint>
-#include <memory>
+#include <stddef.h>
+
 #include <vector>
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/types_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management/types_test.cc
index 0312dc2..22558ec 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management/types_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management/types_test.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
 
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
diff --git a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
index 12f5b6e..ba95135 100644
--- a/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/memory_management_test.cc
@@ -15,8 +15,15 @@
 
 #include "tensorflow/lite/delegates/gpu/common/memory_management.h"
 
+#include <cstddef>
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/types.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/model.cc b/tensorflow/lite/delegates/gpu/common/model.cc
index a2f9da4..7414aba 100644
--- a/tensorflow/lite/delegates/gpu/common/model.cc
+++ b/tensorflow/lite/delegates/gpu/common/model.cc
@@ -15,7 +15,20 @@
 
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 
+#include <stdint.h>
+
+#include <algorithm>
+#include <iterator>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/model.h b/tensorflow/lite/delegates/gpu/common/model.h
index f6d1609..adf93d4 100644
--- a/tensorflow/lite/delegates/gpu/common/model.h
+++ b/tensorflow/lite/delegates/gpu/common/model.h
@@ -24,10 +24,8 @@
 #include <vector>
 
 #include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
 #include "absl/types/any.h"
 #include "absl/types/optional.h"
-#include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index efa75a2..04503c8 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -27,6 +27,7 @@
 #include <variant>
 #include <vector>
 
+#include "absl/base/attributes.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -797,6 +798,7 @@
       case OperationType::ABS:
       case OperationType::COPY:
       case OperationType::COS:
+      case OperationType::ELU:
       case OperationType::EXP:
       case OperationType::LOG:
       case OperationType::RSQRT:
@@ -814,6 +816,8 @@
   bool IsTwoArgumentOperation() const {
     switch (operation_type_) {
       case OperationType::DIV:
+      case OperationType::MAXIMUM:
+      case OperationType::MINIMUM:
       case OperationType::POW:
       case OperationType::SQUARED_DIFF:
       case OperationType::SUB:
@@ -825,8 +829,11 @@
 
   bool IsTwoArgumentOperationWithConst() const {
     switch (operation_type_) {
-      case OperationType::MINIMUM:
+      case OperationType::DIV:
       case OperationType::MAXIMUM:
+      case OperationType::MINIMUM:
+      case OperationType::POW:
+      case OperationType::SQUARED_DIFF:
       case OperationType::SUB:
         return true;
       default:
@@ -1124,6 +1131,17 @@
     // The "larger" input tensor must be bound to 1st input and the "smaller"
     // input tensor ("mask") must be bound to 2nd input.
     if (runtime_tensor0 && runtime_tensor1) {
+      if (input0 == input1) {
+        // Replace MUL(A, A) with POW(A, 2.0).
+        // TODO(b/166831113): Support the same inputs for operations.
+        node->operation.type = ToString(OperationType::POW);
+        ElementwiseAttributes attr;
+        attr.param = 2.0f;
+        node->operation.attributes = std::move(attr);
+        RETURN_IF_ERROR(reader->AddInput(node, 0));
+        return reader->AddOutputs(node);
+      }
+
       BHWC shape0;
       RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
       BHWC shape1;
@@ -2389,6 +2407,7 @@
   absl::Status IsSupported(const TfLiteContext* context,
                            const TfLiteNode* tflite_node,
                            const TfLiteRegistration* registration) final {
+    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
     return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
                               /*outputs=*/1);
   }
@@ -2414,37 +2433,6 @@
   }
 };
 
-class Landmarks2TransformMatrixV2OperationParser
-    : public TFLiteOperationParser {
- public:
-  absl::Status IsSupported(const TfLiteContext* context,
-                           const TfLiteNode* tflite_node,
-                           const TfLiteRegistration* registration) final {
-    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
-                              /*outputs=*/1);
-  }
-
-  absl::Status Parse(const TfLiteNode* tflite_node,
-                     const TfLiteRegistration* registration,
-                     GraphFloat32* graph, ObjectReader* reader) final {
-    Node* node = graph->NewNode();
-    RETURN_IF_ERROR(reader->AddInput(node, 0));  // landmarks
-    RETURN_IF_ERROR(reader->AddOutputs(node));   // transform matrix
-
-    const std::string op_name = "landmarks_to_transform_matrix_v2";
-    node->operation.type = op_name;
-    BHWC output_shape;
-    RETURN_IF_ERROR(ParseCustomAttributes(
-        op_name, registration->version, tflite_node->custom_initial_data,
-        tflite_node->custom_initial_data_size, &(node->operation.attributes),
-        &output_shape));
-
-    auto output_value = graph->FindOutputs(node->id)[0];
-    output_value->tensor.shape = output_shape;
-    return absl::OkStatus();
-  }
-};
-
 class AlignmentPointsToTransformMatrixOperationParser
     : public TFLiteOperationParser {
  public:
@@ -2688,12 +2676,10 @@
       if (custom_name == "TransformLandmarksV2") {
         return std::make_unique<TransformLandmarksV2OperationParser>();
       }
-      if (custom_name == "Landmarks2TransformMatrix") {
+      if (custom_name == "Landmarks2TransformMatrix" ||
+          custom_name == "Landmarks2TransformMatrixV2") {
         return std::make_unique<Landmarks2TransformMatrixOperationParser>();
       }
-      if (custom_name == "Landmarks2TransformMatrixV2") {
-        return std::make_unique<Landmarks2TransformMatrixV2OperationParser>();
-      }
       if (custom_name == "AlignmentPointsToTransformMatrix") {
         return std::make_unique<
             AlignmentPointsToTransformMatrixOperationParser>();
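
The new MUL branch above sidesteps a kernel limitation (the same runtime tensor bound to both inputs; see TODO b/166831113) by rewriting MUL(A, A) as POW(A, 2.0). A standalone numeric check of the identity the rewrite depends on:

#include <cassert>
#include <cmath>

int main() {
  for (float a = -4.0f; a <= 4.0f; a += 0.25f) {
    // a * a and pow(a, 2) agree to well within elementwise-kernel tolerance.
    assert(std::fabs(a * a - std::pow(a, 2.0f)) < 1e-5f);
  }
  return 0;
}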
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.h b/tensorflow/lite/delegates/gpu/common/model_builder.h
index 9d80e96..ab18f05 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.h
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.h
@@ -16,13 +16,12 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_H_
 
-#include <cstdint>
-#include <string>
-
 #include "absl/container/flat_hash_map.h"
-#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc
index b030fb7..16b3d3b 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc
@@ -15,19 +15,27 @@
 
 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
 
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include <any>
+#include <limits>
 #include <string>
+#include <vector>
 
 #include <fp16.h>
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
-#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/context.h"
 #include "tensorflow/lite/context_util.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/utils.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h
index 849ef04..aa2e892 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h
+++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h
@@ -16,6 +16,10 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
 
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
 #include <string>
 
 #include "absl/strings/str_cat.h"
@@ -29,7 +33,6 @@
 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
index c5ee71b..9bc848b 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc
@@ -15,15 +15,21 @@
 
 #include "tensorflow/lite/delegates/gpu/common/model_builder.h"
 
-#include <cstdlib>
+#include <stddef.h>
+#include <stdint.h>
 
-#include <gmock/gmock.h>
+#include <cstdlib>
+#include <utility>
+#include <vector>
+
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
 #include "tensorflow/lite/builtin_ops.h"
-#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/stderr_reporter.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/model_test.cc b/tensorflow/lite/delegates/gpu/common/model_test.cc
index 87f65eb..c63babc 100644
--- a/tensorflow/lite/delegates/gpu/common/model_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_test.cc
@@ -15,11 +15,9 @@
 
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 
-#include <initializer_list>
-#include <vector>
-
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/model_transformer.cc b/tensorflow/lite/delegates/gpu/common/model_transformer.cc
index 81287dd..3be7ec5 100644
--- a/tensorflow/lite/delegates/gpu/common/model_transformer.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_transformer.cc
@@ -19,6 +19,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 
diff --git a/tensorflow/lite/delegates/gpu/common/model_transformer.h b/tensorflow/lite/delegates/gpu/common/model_transformer.h
index fd26673..b640b14 100644
--- a/tensorflow/lite/delegates/gpu/common/model_transformer.h
+++ b/tensorflow/lite/delegates/gpu/common/model_transformer.h
@@ -18,6 +18,7 @@
 
 #include <deque>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "absl/container/flat_hash_set.h"
diff --git a/tensorflow/lite/delegates/gpu/common/object_reader.cc b/tensorflow/lite/delegates/gpu/common/object_reader.cc
index c837fa0..7346b64 100644
--- a/tensorflow/lite/delegates/gpu/common/object_reader.cc
+++ b/tensorflow/lite/delegates/gpu/common/object_reader.cc
@@ -16,13 +16,18 @@
 #include "tensorflow/lite/delegates/gpu/common/object_reader.h"
 
 #include <cstdint>
+#include <optional>
+#include <string>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/utils.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc
index fbffe9d..bf9ee3c 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.cc
+++ b/tensorflow/lite/delegates/gpu/common/operations.cc
@@ -15,11 +15,17 @@
 
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 
+#include <algorithm>
 #include <cstdint>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h
index 563dbde..c1e7abb 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.h
+++ b/tensorflow/lite/delegates/gpu/common/operations.h
@@ -17,14 +17,15 @@
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OPERATIONS_H_
 
 #include <cstdint>
+#include <set>
 #include <string>
 #include <vector>
 
 #include "absl/types/variant.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
-#include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/quantization_util.cc b/tensorflow/lite/delegates/gpu/common/quantization_util.cc
index fe92989..bbd9902 100644
--- a/tensorflow/lite/delegates/gpu/common/quantization_util.cc
+++ b/tensorflow/lite/delegates/gpu/common/quantization_util.cc
@@ -15,9 +15,15 @@
 
 #include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
 
+#include <stdint.h>
+
+#include <vector>
+
 #include "absl/container/flat_hash_map.h"
-#include "tensorflow/lite/builtin_ops.h"
+#include "absl/status/status.h"
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/common/quantization_util.h b/tensorflow/lite/delegates/gpu/common/quantization_util.h
index fc01d61..584f687 100644
--- a/tensorflow/lite/delegates/gpu/common/quantization_util.h
+++ b/tensorflow/lite/delegates/gpu/common/quantization_util.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_
 
+#include <stdint.h>
+
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
diff --git a/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc b/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc
index b5cdaec..ffded54 100644
--- a/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/quantization_util_test.cc
@@ -15,8 +15,18 @@
 
 #include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
 
+#include <stdint.h>
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/util.h"
 
 using ::testing::Eq;
diff --git a/tensorflow/lite/delegates/gpu/common/shape.cc b/tensorflow/lite/delegates/gpu/common/shape.cc
index 074637a..c66ecea 100644
--- a/tensorflow/lite/delegates/gpu/common/shape.cc
+++ b/tensorflow/lite/delegates/gpu/common/shape.cc
@@ -14,6 +14,11 @@
 ==============================================================================*/
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 
+#include <stdint.h>
+
+#include <string>
+#include <vector>
+
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 
diff --git a/tensorflow/lite/delegates/gpu/common/shape.h b/tensorflow/lite/delegates/gpu/common/shape.h
index 544d2c1..a017ff2 100644
--- a/tensorflow/lite/delegates/gpu/common/shape.h
+++ b/tensorflow/lite/delegates/gpu/common/shape.h
@@ -16,9 +16,9 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_SHAPE_H_
 
-#include <sys/types.h>
+#include <stddef.h>
+#include <stdint.h>
 
-#include <algorithm>
 #include <array>
 #include <functional>
 #include <numeric>
@@ -26,8 +26,6 @@
 #include <utility>
 #include <vector>
 
-#include "absl/hash/hash.h"
-
 namespace tflite {
 namespace gpu {
 
diff --git a/tensorflow/lite/delegates/gpu/common/shape_test.cc b/tensorflow/lite/delegates/gpu/common/shape_test.cc
index 4151911..3cbf1fd 100644
--- a/tensorflow/lite/delegates/gpu/common/shape_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/shape_test.cc
@@ -14,10 +14,10 @@
 ==============================================================================*/
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 
-#include <initializer_list>
+#include <stdint.h>
+
 #include <vector>
 
-#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/common/tensor.h b/tensorflow/lite/delegates/gpu/common/tensor.h
index fc39d34..054da89 100644
--- a/tensorflow/lite/delegates/gpu/common/tensor.h
+++ b/tensorflow/lite/delegates/gpu/common/tensor.h
@@ -16,7 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TENSOR_H_
 
-#include <string>
+#include <stdint.h>
+
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
diff --git a/tensorflow/lite/delegates/gpu/common/testing/BUILD b/tensorflow/lite/delegates/gpu/common/testing/BUILD
index a7f97eb..33f7d58 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/testing/BUILD
@@ -10,6 +10,8 @@
     hdrs = ["interpreter_utils.h"],
     deps = [
         "//tensorflow/lite:framework",
+        "//tensorflow/lite:string",
+        "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
@@ -25,13 +27,12 @@
     hdrs = ["tflite_model_reader.h"],
     deps = [
         "//tensorflow/lite:framework_lib",
-        "//tensorflow/lite:kernel_api",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_builder",
+        "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common/transformations:general_transformations",
-        "//tensorflow/lite/kernels:builtin_ops",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD
index b5ceff3..5015096 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD
@@ -24,10 +24,12 @@
     hdrs = ["utils.h"],
     deps = [
         "//tensorflow/lite:framework",
+        "//tensorflow/lite:string",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:builtin_ops",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -48,6 +50,7 @@
     deps = [
         ":feature_parity",
         ":utils",
+        "//tensorflow/lite:framework_lib",
         "//tensorflow/lite/delegates/gpu:gl_delegate",
         "@com_google_googletest//:gtest_main",
     ],
@@ -65,6 +68,7 @@
     deps = [
         ":feature_parity",
         ":utils",
+        "//tensorflow/lite:framework_lib",
         "//tensorflow/lite/delegates/gpu:delegate",
         "@com_google_googletest//:gtest_main",
     ],
@@ -82,6 +86,7 @@
     deps = [
         ":feature_parity",
         ":utils",
+        "//tensorflow/lite:framework_lib",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
         "@com_google_googletest//:gtest_main",
     ],
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h
index 7661a4a..dacb486 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h
@@ -16,9 +16,6 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_FEATURE_PARITY_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_FEATURE_PARITY_H_
 
-#include <functional>
-#include <string>
-#include <utility>
 #include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.h"
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD
index 4fef0a2..56894c8 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD
@@ -20,9 +20,7 @@
     srcs = ["add.cc"],
     hdrs = ["add.h"],
     deps = [
-        "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
-        "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common/testing/feature_parity:utils",
         "@flatbuffers",
     ],
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.cc
index dbb3851..06649b3 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.cc
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.cc
@@ -15,11 +15,14 @@
 
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/add.h"
 
+#include <stdint.h>
+
+#include <string>
+#include <utility>
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h"
-#include "tensorflow/lite/model.h"
 #include "tensorflow/lite/version.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opencl_test.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opencl_test.cc
index 24c0e0c..3dbb863 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opencl_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opencl_test.cc
@@ -13,11 +13,17 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <stdint.h>
+
+#include <memory>
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h"
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h"
 #include "tensorflow/lite/delegates/gpu/delegate.h"
+#include "tensorflow/lite/interpreter.h"
 
 namespace tflite {
 
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opengl_test.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opengl_test.cc
index 2f403d2..ed0aa10 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opengl_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/opengl_test.cc
@@ -15,12 +15,14 @@
 
 #include <cstdint>
 #include <memory>
+#include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h"
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h"
 #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
+#include "tensorflow/lite/interpreter.h"
 
 namespace tflite {
 
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc
index bdcbf7e..6eb94f6 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc
@@ -15,15 +15,18 @@
 
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h"
 
+#include <memory>
+#include <optional>
 #include <ostream>
 #include <string>
+#include <utility>
 
 #include "absl/status/status.h"
 #include "absl/strings/substitute.h"
-#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_type.h"
 
 std::ostream& operator<<(std::ostream& os, const TfLiteTensor& tensor) {
   std::string shape;
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h
index 7c34978f..20d43b8 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h
@@ -16,14 +16,24 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TESTING_FEATURE_PARITY_UTILS_H_
 
+#include <stddef.h>
+
 #include <cstdint>
+#include <memory>
+#include <optional>
 #include <ostream>
 #include <string>
+#include <tuple>
+#include <utility>
 #include <vector>
 
 #include <gmock/gmock.h>
+#include <gtest/gtest.h>
 #include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_type.h"
 
 namespace tflite {
 
diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/xnnpack_test.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/xnnpack_test.cc
index 3d05d64..bdd1295 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/xnnpack_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/xnnpack_test.cc
@@ -13,11 +13,17 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <stdint.h>
+
+#include <memory>
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/feature_parity.h"
 #include "tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h"
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
+#include "tensorflow/lite/interpreter.h"
 
 namespace tflite {
 
diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc
index 08d9448..ae00e21 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc
+++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.cc
@@ -16,15 +16,18 @@
 #include "tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h"
 
 #include <cstring>
+#include <memory>
+#include <string>
 #include <vector>
 
 #include "absl/memory/memory.h"
-#include "tensorflow/lite/context.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_type.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h
index ca2825b..86656ab 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h
+++ b/tensorflow/lite/delegates/gpu/common/testing/interpreter_utils.h
@@ -18,7 +18,7 @@
 
 #include <vector>
 
-#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
diff --git a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc
index a67602c..21c6346 100644
--- a/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc
+++ b/tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.cc
@@ -14,16 +14,18 @@
 ==============================================================================*/
 #include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
 
+#include <stddef.h>
+
 #include <memory>
 
-#include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_builder.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
 #include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/model.h"
 #include "tensorflow/lite/model_builder.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/BUILD b/tensorflow/lite/delegates/gpu/common/transformations/BUILD
index bf26b03..3f5b57d 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/BUILD
+++ b/tensorflow/lite/delegates/gpu/common/transformations/BUILD
@@ -12,9 +12,9 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
         "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:any",
     ],
 )
@@ -24,11 +24,11 @@
     srcs = ["add_quant_adjustments.cc"],
     hdrs = ["add_quant_adjustments.h"],
     deps = [
-        "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:any",
@@ -40,10 +40,13 @@
     srcs = ["add_quant_adjustments_test.cc"],
     deps = [
         ":add_quant_adjustments",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/types:any",
         "@com_google_absl//absl/types:optional",
         "@com_google_googletest//:gtest_main",
@@ -56,9 +59,13 @@
     hdrs = ["fuse_add_to_conv.h"],
     deps = [
         "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -67,8 +74,13 @@
     srcs = ["fuse_add_to_conv_test.cc"],
     deps = [
         ":fuse_add_to_conv",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -82,8 +94,10 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -92,9 +106,13 @@
     srcs = ["fuse_mul_to_conv_test.cc"],
     deps = [
         ":fuse_mul_to_conv",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -123,7 +141,8 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
-        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:any",
     ],
@@ -138,6 +157,8 @@
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/types:any",
         "@com_google_googletest//:gtest_main",
     ],
@@ -151,8 +172,11 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:any",
     ],
 )
@@ -165,6 +189,9 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/types:any",
         "@com_google_googletest//:gtest_main",
     ],
@@ -173,7 +200,6 @@
 cc_library(
     name = "matching",
     hdrs = ["matching.h"],
-    deps = ["//tensorflow/lite/delegates/gpu/common:model"],
 )
 
 cc_library(
@@ -186,7 +212,9 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:any",
@@ -198,10 +226,13 @@
     srcs = ["merge_padding_with_test.cc"],
     deps = [
         ":merge_padding_with",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/types:any",
         "@com_google_googletest//:gtest_main",
     ],
@@ -216,8 +247,11 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -230,6 +264,9 @@
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
         "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest_main",
     ],
 )
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
index 29d70d8..4c6b08e 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
@@ -15,13 +15,18 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
 
+#include <memory>
+#include <string>
+#include <vector>
+
 #include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
 #include "absl/types/any.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc
index 6262d15..7f43d70 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc
@@ -15,15 +15,19 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"
 
+#include <memory>
+#include <optional>
 #include <string>
+#include <vector>
 
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/types/any.h"
-#include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc
index 2ff8498..e9d31a3 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc
@@ -15,14 +15,20 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"
 
-#include <gmock/gmock.h>
+#include <memory>
+#include <string>
+#include <vector>
+
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
 #include "absl/types/any.h"
 #include "absl/types/optional.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc
index fdbd6e0..2b432ba 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc
@@ -15,8 +15,20 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
 
+#include <any>
+#include <memory>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include "absl/strings/string_view.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h
index 53a0cef..26f93dc 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h
@@ -20,7 +20,6 @@
 
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc
index 4a48c7c..7b1f595 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc
@@ -15,10 +15,20 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
 
+#include <any>
+#include <memory>
+#include <string>
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 using ::testing::FloatNear;
 using ::testing::Pointwise;
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc
index 25ec629..41bd485 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc
@@ -15,9 +15,18 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h"
 
+#include <any>
+#include <memory>
+#include <string>
+#include <variant>
+#include <vector>
+
+#include "absl/strings/string_view.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h
index 8d64ae5..92fab45 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h
@@ -20,7 +20,6 @@
 
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc
index ea990dd..2a14f96 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv_test.cc
@@ -15,11 +15,20 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h"
 
+#include <any>
+#include <memory>
+#include <string>
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 using ::testing::FloatNear;
 using ::testing::Pointwise;
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc b/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc
index f9ae7f4..6d59b0b 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/general_transformations.cc
@@ -15,6 +15,9 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
 
+#include <memory>
+
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"
 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.h"
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc
index 1236cde..226e7d4 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc
@@ -15,11 +15,17 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h"
 
+#include <memory>
+#include <string>
+#include <vector>
+
 #include "absl/memory/memory.h"
 #include "absl/types/any.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
-#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc
index d3606d4..2703533 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc
@@ -15,13 +15,18 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h"
 
-#include <gmock/gmock.h>
+#include <memory>
+#include <string>
+#include <vector>
+
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
 #include "absl/types/any.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc
index 17aac83..51335a8 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc
@@ -15,11 +15,19 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h"
 
+#include <memory>
+#include <string>
+#include <vector>
+
 #include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/any.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc
index f8be321..5d6bca2 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc
@@ -15,12 +15,18 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h"
 
-#include <gmock/gmock.h>
+#include <memory>
+#include <string>
+#include <vector>
+
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
 #include "absl/types/any.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/matching.h b/tensorflow/lite/delegates/gpu/common/transformations/matching.h
index 0dfd21e..b28c8b0 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/matching.h
+++ b/tensorflow/lite/delegates/gpu/common/transformations/matching.h
@@ -18,9 +18,10 @@
 
 // This file provides predicates to match subgraphs.
 
+#include <algorithm>
+#include <iterator>
 #include <string>
-
-#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include <vector>
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc
index 6a4e24b..509d715 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc
@@ -15,16 +15,22 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
 
+#include <memory>
 #include <string>
+#include <variant>
 #include <vector>
 
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/any.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/transformations/matching.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc
index 40029ef..e8ff5ed 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with_test.cc
@@ -15,13 +15,19 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
 
-#include <gmock/gmock.h>
+#include <memory>
+#include <string>
+#include <vector>
+
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
 #include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc
index 6cc3708..a9966d2 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc
@@ -15,14 +15,25 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h"
 
+#include <algorithm>
+#include <any>
+#include <functional>
+#include <iterator>
+#include <memory>
 #include <string>
+#include <utility>
+#include <variant>
 #include <vector>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc
index a6aafee..0ccaf53 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop_test.cc
@@ -15,12 +15,20 @@
 
 #include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h"
 
+#include <any>
+#include <memory>
+#include <string>
+#include <vector>
+
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/status/status.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
 
 namespace tflite {
 namespace gpu {
diff --git a/tensorflow/lite/delegates/gpu/common/types.h b/tensorflow/lite/delegates/gpu/common/types.h
index 8725b42..4ddb46f 100644
--- a/tensorflow/lite/delegates/gpu/common/types.h
+++ b/tensorflow/lite/delegates/gpu/common/types.h
@@ -19,7 +19,6 @@
 #include <array>
 #include <cstddef>
 #include <cstdint>
-#include <string>
 
 #include <fp16.h>
 
diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.cc b/tensorflow/lite/delegates/gpu/common/winograd_util.cc
index 16be80e..4b9581d 100644
--- a/tensorflow/lite/delegates/gpu/common/winograd_util.cc
+++ b/tensorflow/lite/delegates/gpu/common/winograd_util.cc
@@ -15,6 +15,9 @@
 
 #include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
 
+#include <cmath>
+#include <vector>
+
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.h b/tensorflow/lite/delegates/gpu/common/winograd_util.h
index 2e80a6c..e88ceac 100644
--- a/tensorflow/lite/delegates/gpu/common/winograd_util.h
+++ b/tensorflow/lite/delegates/gpu/common/winograd_util.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_WINOGRAD_UTIL_H_
 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_WINOGRAD_UTIL_H_
 
+#include <vector>
+
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc
index 5ae2a53..439eb0a 100644
--- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc
+++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.cc
@@ -15,7 +15,10 @@
 
 #include "tensorflow/lite/delegates/gpu/common/workgroup_selection.h"
 
+#include <math.h>
+
 #include <set>
+#include <vector>
 
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 
diff --git a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h
index a08bfce..67c51b4 100644
--- a/tensorflow/lite/delegates/gpu/common/workgroup_selection.h
+++ b/tensorflow/lite/delegates/gpu/common/workgroup_selection.h
@@ -18,9 +18,6 @@
 
 #include <vector>
 
-#include "tensorflow/lite/delegates/gpu/common/status.h"
-#include "tensorflow/lite/delegates/gpu/common/types.h"
-
 namespace tflite {
 namespace gpu {
 
diff --git a/tensorflow/lite/delegates/gpu/gl/variable.h b/tensorflow/lite/delegates/gpu/gl/variable.h
index 1c5bb26..5237481 100644
--- a/tensorflow/lite/delegates/gpu/gl/variable.h
+++ b/tensorflow/lite/delegates/gpu/gl/variable.h
@@ -18,6 +18,7 @@
 
 #include <array>
 #include <cstdint>
+#include <string>
 #include <vector>
 
 #include "absl/types/variant.h"
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index f4f4c18..e90f8a4 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -122,6 +122,7 @@
     srcs = ["conv.cc"],
     hdrs = ["conv.h"],
     deps = [
+        "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
diff --git a/tensorflow/lite/delegates/nnapi/BUILD b/tensorflow/lite/delegates/nnapi/BUILD
index 53ac979..7a34b08 100644
--- a/tensorflow/lite/delegates/nnapi/BUILD
+++ b/tensorflow/lite/delegates/nnapi/BUILD
@@ -1,4 +1,5 @@
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = [
@@ -26,6 +27,7 @@
         "nnapi_delegate.h",
         "nnapi_delegate_kernel.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     deps = [
         "//tensorflow/lite:allocation",
         "//tensorflow/lite:kernel_api",
diff --git a/tensorflow/lite/delegates/utils/dummy_delegate/README.md b/tensorflow/lite/delegates/utils/dummy_delegate/README.md
index ae17f1b..d55ba42 100644
--- a/tensorflow/lite/delegates/utils/dummy_delegate/README.md
+++ b/tensorflow/lite/delegates/utils/dummy_delegate/README.md
@@ -20,11 +20,11 @@
 
 ## Testing & Tooling
 
-There are currently **two optionss** to plug in a newly created TFLite delegate
+There are currently **two options** to plug in a newly created TFLite delegate
 to reuse existing TFLite kernel tests and tooling:
 
 - Utilize the **[delegate registrar](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/delegates)**
-mechansim
+mechanism
 - Utilize the
 **[external delegate](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/external)**
 mechanism.
@@ -126,13 +126,13 @@
 and tooling**, we first create an external delegate adaptor like the [`external_delegate_adaptor.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/utils/dummy_delegate/external_delegate_adaptor.cc) here, and create the corresponding BUILD target
 to build a dynamic library.
 
-Afterwards, one could build binaries or use pre-built ones that are linked with
-the
+Afterwards, one could build binaries or use pre-built ones to run with the
+dummy delegate as long as the binary is linked with the
 [`external_delegate_provider`](https://github.com/tensorflow/tensorflow/blob/8c6f2d55762f3fc94f98fdd8b3c5d59ee1276dba/tensorflow/lite/tools/delegates/BUILD#L145-L159)
 library which supports command-line flags as described
 [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/delegates#external-delegate-provider).
-Note this delegate provider has already been linked to existing testing and
-tooling binaries.
+Note this external delegate provider has already been linked to existing testing
+and tooling binaries.
 
 For example, the following illustrates how to benchmark the dummy delegate here
 via this external-delegate approach. We could use similar commands for testing
diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD
index 3c580ed..f9825f6 100644
--- a/tensorflow/lite/delegates/xnnpack/BUILD
+++ b/tensorflow/lite/delegates/xnnpack/BUILD
@@ -1,4 +1,5 @@
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = ["//visibility:public"],
@@ -38,6 +39,7 @@
 cc_library(
     name = "xnnpack_delegate_hdrs_only",
     hdrs = ["xnnpack_delegate.h"],
+    compatible_with = get_compatible_with_portable(),
     visibility = ["//tensorflow/lite:__subpackages__"],
     deps = [
         "//tensorflow/lite/c:common",
diff --git a/tensorflow/lite/experimental/delegates/coreml/BUILD b/tensorflow/lite/experimental/delegates/coreml/BUILD
index 2985cd3..ee20970 100644
--- a/tensorflow/lite/experimental/delegates/coreml/BUILD
+++ b/tensorflow/lite/experimental/delegates/coreml/BUILD
@@ -59,6 +59,9 @@
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/delegates:utils",
         "//tensorflow/lite/experimental/delegates/coreml/builders:op_builder",
+        "//tensorflow/lite/experimental/delegates/coreml/builders:op_validator",
+        "//tensorflow/lite/experimental/delegates/coreml/builders:util",
+        "//tensorflow/lite/kernels:kernel_util",
     ],
 )
 
diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc
index c775f4f..bf8fb33 100644
--- a/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc
+++ b/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc
@@ -46,6 +46,8 @@
       return AddBuilder(CreateLogisticOpBuilder, node);
     case kTfLiteBuiltinMaxPool2d:
       return AddBuilder(CreateMaxPool2dOpBuilder, node);
+    case kTfLiteBuiltinMean:
+      return AddBuilder(CreateMeanOpBuilder, node);
     case kTfLiteBuiltinMirrorPad:
       return AddBuilder(CreateMirrorPadOpBuilder, node);
     case kTfLiteBuiltinMul:
diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h b/tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h
index 4245021..c70dbf2 100644
--- a/tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h
+++ b/tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h
@@ -32,6 +32,7 @@
 OpBuilder* CreateHardSwishOpBuilder(GraphBuilder* graph_builder);
 OpBuilder* CreateLogisticOpBuilder(GraphBuilder* graph_builder);
 OpBuilder* CreateMaxPool2dOpBuilder(GraphBuilder* graph_builder);
+OpBuilder* CreateMeanOpBuilder(GraphBuilder* graph_builder);
 OpBuilder* CreateMirrorPadOpBuilder(GraphBuilder* graph_builder);
 OpBuilder* CreateMulOpBuilder(GraphBuilder* graph_builder);
 // PAD handles PAD and PADV2 together.
diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h b/tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h
index b099fd7..a97a0f3 100644
--- a/tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h
+++ b/tensorflow/lite/experimental/delegates/coreml/builders/op_validator.h
@@ -31,6 +31,8 @@
 bool IsFullyConnectedOpSupported(const TfLiteRegistration* registration,
                                  const TfLiteNode* node,
                                  TfLiteContext* context);
+bool IsMeanOpSupported(const TfLiteRegistration* registration,
+                       const TfLiteNode* node, TfLiteContext* context);
 bool IsMirrorPadOpSupported(const TfLiteRegistration* registration,
                             const TfLiteNode* node, TfLiteContext* context);
 bool IsPadOpSupported(const TfLiteRegistration* registration,
diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.cc
index 8859639..d3e3f6b 100644
--- a/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.cc
+++ b/tensorflow/lite/experimental/delegates/coreml/builders/pooling_layer_builder.cc
@@ -18,6 +18,8 @@
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
 
 namespace tflite {
 namespace delegates {
@@ -29,13 +31,15 @@
     case kTfLiteBuiltinAveragePool2d:
       GetDebugName("PoolingLayerBuilder (AVERAGE)", node_id_, str_debug_name_);
       break;
-
     case kTfLiteBuiltinMaxPool2d:
       GetDebugName("PoolingLayerBuilder (MAX)", node_id_, str_debug_name_);
       break;
     case kTfLiteBuiltinL2Pool2d:
-      GetDebugName("PoolingLayerBuilder (L2, unsupported)",
-                   node_id_, str_debug_name_);
+      GetDebugName("PoolingLayerBuilder (L2, unsupported)", node_id_,
+                   str_debug_name_);
+      break;
+    case kTfLiteBuiltinMean:
+      GetDebugName("PoolingLayerBuilder (MEAN)", node_id_, str_debug_name_);
       break;
     default:
       GetDebugName("PoolingLayerBuilder (ERROR)", node_id_, str_debug_name_);
@@ -44,13 +48,18 @@
 }
 
 CoreML::Specification::NeuralNetworkLayer* PoolingLayerBuilder::Build() {
-  if (layer_ == nullptr) {
-    layer_.reset(new CoreML::Specification::NeuralNetworkLayer);
-  }
   layer_->set_name(DebugName());
+  auto* pooling_params = layer_->mutable_pooling();
+
+  if (pooling_type_ == kTfLiteBuiltinMean) {
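+    // Mean over H and W is expressed as Core ML global average pooling, so
+    // no stride or kernel-size parameters are set on this path.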
+    pooling_params->set_type(
+        CoreML::Specification::PoolingLayerParams::AVERAGE);
+    pooling_params->set_globalpooling(true);
+    return layer_.release();
+  }
+
   const TfLitePoolParams* params =
       reinterpret_cast<const TfLitePoolParams*>(builtin_data_);
-  auto* pooling_params = layer_->mutable_pooling();
   pooling_params->mutable_stride()->Add(params->stride_height);
   pooling_params->mutable_stride()->Add(params->stride_width);
   pooling_params->mutable_kernelsize()->Add(params->filter_height);
@@ -89,7 +98,12 @@
 
 TfLiteStatus PoolingLayerBuilder::RegisterInputs(const TfLiteIntArray* inputs,
                                                  TfLiteContext* context) {
-  if (inputs->size != 1) {
+  if (pooling_type_ == kTfLiteBuiltinMean) {
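+    // Mean carries a second input: the tensor of axis indices to reduce over.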
+    if (inputs->size != 2) {
+      TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Mean!");
+      return kTfLiteError;
+    }
+  } else if (inputs->size != 1) {
     TF_LITE_KERNEL_LOG(context, "Wrong # of inputs to Pooling!.");
     return kTfLiteError;
   }
@@ -115,6 +129,38 @@
   return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMaxPool2d);
 }
 
+OpBuilder* CreateMeanOpBuilder(GraphBuilder* graph_builder) {
+  return new PoolingLayerBuilder(graph_builder, kTfLiteBuiltinMean);
+}
+
+// Only supports averaging over the H and W dimensions, as the Mean op is
+// mapped to Core ML's global average pooling.
+bool IsMeanOpSupported(const TfLiteRegistration* registration,
+                       const TfLiteNode* node, TfLiteContext* context) {
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  const TfLiteTensor* axis = GetInput(context, node, 1);
+  const auto* params =
+      reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
+
+  if (!params->keep_dims) {
+    TF_LITE_KERNEL_LOG(context, "keep_dims should be true for Mean op.");
+    return false;
+  }
+  if (input->dims->size != 4) {
+    TF_LITE_KERNEL_LOG(context, "Mean op is only supported for 4D input.");
+    return false;
+  }
+  const int* axis_data = GetTensorData<int>(axis);
+  std::vector<bool> axis_mask = {false, true, true, false};
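+  // Axes may be negative (indexing the 4D NHWC input from the end), so
+  // (axis + 4) % 4 normalizes them: e.g. -3 -> 1 (H) and -2 -> 2 (W). Only
+  // the H and W entries of axis_mask are reducible.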
+  for (int i = 0; i < axis->dims->data[0]; ++i) {
+    if (!axis_mask[(axis_data[i] + 4) % 4]) {
+      TF_LITE_KERNEL_LOG(context,
+                         "Mean op should reduce for H and W dimensions.");
+      return false;
+    }
+  }
+  return true;
+}
+
 }  // namespace coreml
 }  // namespace delegates
 }  // namespace tflite
diff --git a/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.mm b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.mm
index 2cca58a..173546b 100644
--- a/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.mm
+++ b/tensorflow/lite/experimental/delegates/coreml/coreml_delegate.mm
@@ -109,6 +109,9 @@
     case kTfLiteBuiltinMirrorPad: {
       return delegates::coreml::IsMirrorPadOpSupported(registration, node, context);
     }
+    case kTfLiteBuiltinMean: {
+      return delegates::coreml::IsMeanOpSupported(registration, node, context);
+    }
     case kTfLiteBuiltinMul: {
       return node->builtin_data != nullptr &&
              delegates::coreml::IsBinaryOpSupported(registration, node, context);
diff --git a/tensorflow/lite/experimental/ios/hide_symbols_with_allowlist.sh b/tensorflow/lite/experimental/ios/hide_symbols_with_allowlist.sh
index 6841643..27253cd 100755
--- a/tensorflow/lite/experimental/ios/hide_symbols_with_allowlist.sh
+++ b/tensorflow/lite/experimental/ios/hide_symbols_with_allowlist.sh
@@ -33,8 +33,7 @@
 # LD_DEBUGGABLE_FLAGS="-d"
 
 # Exits if C++ symbols are found in the allowlist.
-if grep -q "^__Z" "${ALLOWLIST_FILE_PATH}"
-then
+if grep -q "^__Z" "${ALLOWLIST_FILE_PATH}"; then
   echo "ERROR: Failed in symbol hiding. This rule does not permit hiding of" \
        "C++ symbols due to possible serious problems mixing symbol hiding," \
        "shared libraries and the C++ runtime." \
@@ -59,8 +58,7 @@
 merge_cmd=(xcrun lipo)
 
 # Merges object files and hide symbols for each architecture.
-for arch in "${archs[@]}"
-do
+for arch in "${archs[@]}"; do
     archdir=$(mktemp -t "${arch}" -d)
     arch_file="${archdir}/${arch}"
 
diff --git a/tensorflow/lite/experimental/resource/BUILD b/tensorflow/lite/experimental/resource/BUILD
index fce783d..b1b53af 100644
--- a/tensorflow/lite/experimental/resource/BUILD
+++ b/tensorflow/lite/experimental/resource/BUILD
@@ -1,3 +1,5 @@
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
+
 package(
     default_visibility = ["//visibility:public"],
     licenses = ["notice"],  # Apache 2.0
@@ -16,6 +18,7 @@
         "resource_variable.h",
         "static_hashtable.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     deps = [
         "//tensorflow/lite:string_util",
         "//tensorflow/lite/c:common",
diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml
index 4729ac8..097a11a 100644
--- a/tensorflow/lite/g3doc/_book.yaml
+++ b/tensorflow/lite/g3doc/_book.yaml
@@ -76,7 +76,7 @@
         path: /lite/guide/roadmap
 
       - heading: "Convert a model"
-      - title: "TensorFlow Lite converter"
+      - title: "Overview"
         path: /lite/convert/
       - title: "Python API"
         path: /lite/convert/python_api
@@ -88,10 +88,10 @@
         path: /lite/convert/rnn
       - title: "Add metadata"
         path: /lite/convert/metadata
-      - title: "Composite operation fusion"
-        path: /lite/convert/operation_fusion
-      - title: "1.x compatibility"
-        path: /lite/convert/1x_compatibility
+      - title: "Sample models"
+        path: /lite/guide/hosted_models
+      - title: "API updates"
+        path: /lite/convert/api_updates
 
       - heading: "Create a model"
       - title: "TensorFlow Lite Model Maker"
@@ -101,27 +101,48 @@
       - heading: "Inference"
       - title: "Overview"
         path: /lite/guide/inference
-      - title: "Integrate models with metadata"
-        path: /lite/guide/codegen
-      - title: "Custom operators"
-        path: /lite/guide/ops_custom
-      - title: "Operator versions"
-        path: /lite/guide/ops_version
       - title: "Operator compatibility"
         path: /lite/guide/ops_compatibility
-      - title: "Select operators from TensorFlow"
+      - title: "Select operators"
         path: /lite/guide/ops_select
+      - title: "Custom operators"
+        path: /lite/guide/ops_custom
+      - title: "Fused operators"
+        path: /lite/convert/operation_fusion
+      - title: "Operator versions"
+        path: /lite/guide/ops_version
         status: experimental
-      - title: "Process input and output data"
-        path: /lite/guide/lite_support
-      - title: "List of hosted models"
-        path: /lite/guide/hosted_models
+
+      - heading: "Inference with metadata"
+      - title: "Overview"
+        path: /lite/inference_with_metadata/overview
+      - title: "Generate model interfaces with codegen"
+        path: /lite/inference_with_metadata/codegen
+      - title: "Integrate models with Task Library"
+        path: /lite/inference_with_metadata/task_library/overview
+        section:
+        - title: "ImageClassifier"
+          path: /lite/inference_with_metadata/task_library/image_classifier
+        - title: "ObjectDetector"
+          path: /lite/inference_with_metadata/task_library/object_detector
+        - title: "ImageSegmenter"
+          path: /lite/inference_with_metadata/task_library/image_segmenter
+        - title: "NLClassifier"
+          path: /lite/inference_with_metadata/task_library/nl_classifier
+        - title: "BertNLClassifier"
+          path: /lite/inference_with_metadata/task_library/bert_nl_classifier
+        - title: "BertQuestionAnswerer"
+          path: /lite/inference_with_metadata/task_library/bert_question_answerer
+        - title: "Customized API"
+          path: /lite/inference_with_metadata/task_library/customized_task_api
+      - title: "Customize input and output data processing"
+        path: /lite/inference_with_metadata/lite_support
 
       - heading: "Performance"
       - title: "Best practices"
         path: /lite/performance/best_practices
-      - title: "Benchmarks"
-        path: /lite/performance/benchmarks
+      - title: "Measurement"
+        path: /lite/performance/measurement
       - title: "Delegates"
         path: /lite/performance/delegates
         status: experimental
@@ -209,7 +230,7 @@
     - name: "API"
       skip_translation: true
       contents:
-      - title: API Reference
+      - title: "API Reference"
         path: /lite/api_docs/
       - heading: "Python"
       - title: "Overview"
@@ -218,7 +239,7 @@
       - heading: "Android (Java)"
       - include: /lite/api_docs/java/_toc.yaml
       - heading: "C++"
-      - title: Overview
+      - title: "Overview"
         path: /lite/api_docs/cc/
       - include: /lite/api_docs/cc/_doxygen.yaml
 
diff --git a/tensorflow/lite/g3doc/convert/1x_compatibility.md b/tensorflow/lite/g3doc/convert/1x_compatibility.md
deleted file mode 100644
index ceb99ba..0000000
--- a/tensorflow/lite/g3doc/convert/1x_compatibility.md
+++ /dev/null
@@ -1,120 +0,0 @@
-# TensorFlow 1.x Compatibility <a name="differences"></a>
-
-The `tf.lite.TFLiteConverter` Python API was updated between TensorFlow 1.x and
-2.x. This document explains the differences between the two versions, and
-provides information about how to use the 1.x version if required.
-
-If any of the changes raise concerns, please file a
-[GitHub Issue](https://github.com/tensorflow/tensorflow/issues).
-
-Note: We highly recommend that you
-[migrate your TensorFlow 1.x code to TensorFlow 2.x code](https://www.tensorflow.org/guide/migrate)
-.
-
-## Model formats
-
-#### SavedModel and Keras
-
-The `tf.lite.TFLiteConverter` API supports SavedModel and Keras HDF5 files
-generated in both TensorFlow 1.x and 2.x.
-
-#### Frozen Graph
-
-Note: TensorFlow 2.x no longer supports the generation of frozen graph models.
-
-The `tf.compat.v1.lite.TFLiteConverter` API supports frozen graph models
-generated in TensorFlow 1.x, as shown below:
-
-```python
-import tensorflow as tf
-# Path to the frozen graph file
-graph_def_file = 'frozen_graph.pb'
-# A list of the names of the model's input tensors
-input_arrays = ['input_name']
-# A list of the names of the model's output tensors
-output_arrays = ['output_name']
-# Load and convert the frozen graph
-converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
-  graph_def_file, input_arrays, output_arrays)
-tflite_model = converter.convert()
-# Write the converted model to disk
-open("converted_model.tflite", "wb").write(tflite_model)
-```
-
-## Converter attributes
-
-#### Renamed attributes
-
-The following 1.x attribute has been renamed in 2.x.
-
-*   `target_ops` has been renamed to `target_spec.supported_ops` - In 2.x, in
-    line with future additions to the optimization framework, it has become an
-    attribute of `TargetSpec` and has been renamed to `supported_ops`.
-
-#### Unsupported attributes
-
-The following 1.x attributes have been removed in 2.x.
-
-*   _Quantization_ - In 2.x,
-    [quantize aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training)
-    is supported through the Keras API and
-    [post training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)
-    uses fewer streamlined converter flags. Thus, the following attributes and
-    methods related to quantization have been removed:
-    *   `inference_type`
-    *   `quantized_input_stats`
-    *   `post_training_quantize`
-    *   `default_ranges_stats`
-    *   `reorder_across_fake_quant`
-    *   `change_concat_input_ranges`
-    *   `get_input_arrays()`
-*   _Visualization_ - In 2.x, the recommended approach for visualizing a
-    TensorFlow Lite graph is to use
-    [visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py)
-    . Unlike GraphViz, it enables users to visualize the graph after post
-    training quantization has occurred. Thus, the following attributes related
-    to graph visualization have been removed:
-    *   `output_format`
-    *   `dump_graphviz_dir`
-    *   `dump_graphviz_video`
-*   _Frozen graph_ - In 2.x, the frozen graph model format has been removed.
-    Thus, the following attribute related to frozen graphs has been removed:
-    *   `drop_control_dependency`
-
-## Unsupported APIs
-
-The following section explains several significant features in 1.x that have
-been removed in 2.x.
-
-#### Conversion APIs
-
-The following methods were deprecated in 1.x and have been removed in 2.x:
-
-*   `lite.toco_convert`
-*   `lite.TocoConverter`
-
-#### `lite.constants` API
-
-The `lite.constants` API was removed in 2.x in order to decrease duplication
-between TensorFlow and TensorFlow Lite. The following list maps the
-`lite.constant` type to the TensorFlow type:
-
-*   `lite.constants.FLOAT`: `tf.float32`
-*   `lite.constants.INT8`: `tf.int8`
-*   `lite.constants.INT32`: `tf.int32`
-*   `lite.constants.INT64`: `tf.int64`
-*   `lite.constants.STRING`: `tf.string`
-*   `lite.constants.QUANTIZED_UINT8`: `tf.uint8`
-
-Additionally, the deprecation of the `output_format` flag in `TFLiteConverter`
-led to the removal of the following constants:
-
-*   `lite.constants.TFLITE`
-*   `lite.constants.GRAPHVIZ_DOT`
-
-#### `lite.OpHint` API
-
-The `OpHint` API is currently unsupported due to an incompatibility with the 2.x
-APIs. This API enables conversion of LSTM based models. Support for LSTMs in 2.x
-is being investigated. All related `lite.experimental` APIs have been removed
-due to this issue.
diff --git a/tensorflow/lite/g3doc/convert/api_updates.md b/tensorflow/lite/g3doc/convert/api_updates.md
new file mode 100644
index 0000000..a990b4f
--- /dev/null
+++ b/tensorflow/lite/g3doc/convert/api_updates.md
@@ -0,0 +1,48 @@
+# API Updates <a name="api_updates"></a>
+
+This page provides information about updates made to the
+`tf.lite.TFLiteConverter` [Python API](index.md) in TensorFlow 2.x.
+
+Note: If any of the changes raise concerns, please file a
+[GitHub issue](https://github.com/tensorflow/tensorflow/issues/new?template=60-tflite-converter-issue.md).
+
+*   TensorFlow 2.3
+
+    *   Support integer (previously, only float) input/output types for
+        integer quantized models, using the new `inference_input_type` and
+        `inference_output_type` attributes. Refer to this
+        [example usage](../performance/post_training_quantization.md#integer_only).
+    *   Support conversion and resizing of models with dynamic dimensions.
+    *   Added a new experimental quantization mode with 16-bit activations and
+        8-bit weights.
+
+*   TensorFlow 2.2
+
+    *   By default, leverage [MLIR-based conversion](https://mlir.llvm.org/),
+        Google's cutting-edge compiler technology for machine learning. This
+        enables conversion of new classes of models, including Mask R-CNN and
+        Mobile BERT, and supports models with functional control flow.
+
+*   TensorFlow 2.0 vs TensorFlow 1.x
+
+    *   Renamed the `target_ops` attribute to `target_spec.supported_ops`
+    *   Removed the following attributes:
+        *   _quantization_: `inference_type`, `quantized_input_stats`,
+            `post_training_quantize`, `default_ranges_stats`,
+            `reorder_across_fake_quant`, `change_concat_input_ranges`,
+            `get_input_arrays()`. Instead,
+            [quantize aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training)
+            is supported through the `tf.keras` API and
+            [post training quantization](../performance/post_training_quantization.md)
+            uses fewer attributes.
+        *   _visualization_: `output_format`, `dump_graphviz_dir`,
+            `dump_graphviz_video`. Instead, the recommended approach for
+            visualizing a TensorFlow Lite model is to use
+            [visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py).
+        *   _frozen graphs_: `drop_control_dependency`, as frozen graphs are
+            unsupported in TensorFlow 2.x.
+    *   Removed other converter APIs such as `tf.lite.toco_convert` and
+        `tf.lite.TocoConverter`
+    *   Removed other related APIs such as `tf.lite.OpHint` and
+        `tf.lite.constants` (the `tf.lite.constants.*` types have been mapped to
+        `tf.*` TensorFlow data types, to reduce duplication)
diff --git a/tensorflow/lite/g3doc/convert/metadata.md b/tensorflow/lite/g3doc/convert/metadata.md
index 4279e40..667e12f 100644
--- a/tensorflow/lite/g3doc/convert/metadata.md
+++ b/tensorflow/lite/g3doc/convert/metadata.md
@@ -7,9 +7,9 @@
 *   human readable parts which convey the best practice when using the model,
     and
 *   machine readable parts that can be leveraged by code generators, such as the
-    [TensorFlow Lite Android code generator](../guide/codegen.md#generate-code-with-tensorflow-lite-android-code-generator)
+    [TensorFlow Lite Android code generator](../inference_with_metadata/codegen.md#generate-code-with-tensorflow-lite-android-code-generator)
     and the
-    [Android Studio ML Binding feature](../guide/codegen.md#generate-code-with-android-studio-ml-model-binding).
+    [Android Studio ML Binding feature](../inference_with_metadata/codegen.md#generate-code-with-android-studio-ml-model-binding).
 
 All image models published on
 [TensorFlow Lite hosted models](https://www.tensorflow.org/lite/guide/hosted_models)
@@ -47,9 +47,9 @@
     [SubGraphMetadata.output_tensor_metadata](https://github.com/tensorflow/tflite-support/blob/4cd0551658b6e26030e0ba7fc4d3127152e0d4ae/tensorflow_lite_support/metadata/metadata_schema.fbs#L599).
 
 Since TensorFlow Lite only supports single subgraph at this point, the
-[TensorFlow Lite code generator](../guide/codegen.md#generate-code-with-tensorflow-lite-android-code-generator)
+[TensorFlow Lite code generator](../inference_with_metadata/codegen.md#generate-code-with-tensorflow-lite-android-code-generator)
 and the
-[Android Studio ML Binding feature](../guide/codegen.md#generate-code-with-android-studio-ml-model-binding)
+[Android Studio ML Binding feature](../inference_with_metadata/codegen.md#generate-code-with-android-studio-ml-model-binding)
 will use `ModelMetadata.name` and `ModelMetadata.description`, instead of
 `SubGraphMetadata.name` and `SubGraphMetadata.description`, when displaying
 metadata and generating code.
@@ -82,11 +82,11 @@
 [Pack mtadata and associated files into the model](#pack-metadata-and-associated-files-into-the-model)
 for more details.
 
-The associate file information can be recored in the metadata. Depending on the
+The associated file information can be recorded in the metadata. Depending on the
 file type and where the file is attached to (i.e. `ModelMetadata`,
 `SubGraphMetadata`, and `TensorMetadata`),
-[the TensorFlow Lite Android code generator](../guide/codegen.md) may apply
-corresponding pre/post processing automatically to the object. See
+[the TensorFlow Lite Android code generator](../inference_with_metadata/codegen.md)
+may apply corresponding pre/post processing automatically to the object. See
 [the \<Codegen usage\> section of each associate file type](https://github.com/tensorflow/tflite-support/blob/4cd0551658b6e26030e0ba7fc4d3127152e0d4ae/tensorflow_lite_support/metadata/metadata_schema.fbs#L77-L127)
 in the schema for more details.
 
@@ -161,8 +161,7 @@
 and the
 [TensorFlow Lite C++ API](https://github.com/tensorflow/tensorflow/blob/09ec15539eece57b257ce9074918282d88523d56/tensorflow/lite/c/common.h#L391).
 \
-[2] The
-[metadata extractor library](../guide/codegen.md#read-the-metadata-from-models)
+[2] The [metadata extractor library](#read-the-metadata-from-models)
 
 When processing image data for uint8 models, normalization and quantization are
 sometimes skipped. It is fine to do so when the pixel values are in the range of
@@ -348,6 +347,9 @@
   f.write(json_file)
 ```
 
+Android Studio also supports displaying metadata through the
+[Android Studio ML Binding feature](https://developer.android.com/studio/preview/features#tensor-flow-lite-models).
+
 ## Metadata versioning
 
 The
@@ -391,5 +393,94 @@
 smallest compatible version indicated by the file identifier. The minimum
 necessary metadata parser version is automatically populated by the
 `MetadataPopulator` when the metadata is populated into a TFLite model. See the
-[metadata extractor](../guide/codegen.md#read-the-metadata-from-models) about
-how the minimum necessary metadata parser version is used.
+[metadata extractor](#read-the-metadata-from-models) for more information on how
+the minimum necessary metadata parser version is used.
+
+## Read the metadata from models
+
+The Metadata Extractor library is a convenient tool to read the metadata and
+associated files from models across different platforms (see the
+[Java version](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/metadata/java)
+and the
+[C++ version](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/metadata/cc)).
+You can build your own metadata extractor tool in other languages using the
+Flatbuffers library.
+
+### Read the metadata in Java
+
+To use the Metadata Extractor library in your Android app, we recommend using
+the
+[TensorFlow Lite Metadata AAR hosted at JCenter](https://bintray.com/google/tensorflow/tensorflow-lite-metadata).
+It contains the `MetadataExtractor` class, as well as the FlatBuffers Java
+bindings for the
+[metadata schema](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/metadata_schema.fbs)
+and the
+[model schema](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs).
+
+You can specify this in your `build.gradle` dependencies as follows:
+
+```build
+dependencies {
+    implementation 'org.tensorflow:tensorflow-lite-metadata:0.0.0-nightly'
+}
+```
+
+You can initialize a `MetadataExtractor` object with a `ByteBuffer` that points
+to the model:
+
+```java
+public MetadataExtractor(ByteBuffer buffer);
+```
+
+The `ByteBuffer` must remain unchanged for the entire lifetime of the
+`MetadataExtractor` object. The initialization may fail if the Flatbuffers file
+identifier of the model metadata does not match that of the metadata parser. See
+[metadata versioning](#metadata-versioning) for more information.
+
+With matching file identifiers, the metadata extractor will successfully read
+metadata generated from all past and future schemas, thanks to FlatBuffers'
+forward and backward compatibility mechanism. However, fields from future
+schemas cannot be extracted by older metadata extractors. The
+[minimum necessary parser version](#the-minimum-necessary-metadata-parser-version)
+of the metadata indicates the minimum version of the metadata parser that can
+read the metadata FlatBuffers in full. You can use the following method to
+verify whether the minimum necessary parser version condition is met:
+
+```java
+public final boolean isMinimumParserVersionSatisfied();
+```
+
+Passing in a model without metadata is allowed. However, invoking methods that
+read from the metadata will cause runtime errors. You can check if a model has
+metadata by invoking the `hasMetadata` method:
+
+```java
+public boolean hasMetadata();
+```
+
+`MetadataExtractor` provides convenient functions for you to get the
+input/output tensors' metadata. For example,
+
+```java
+public int getInputTensorCount();
+public TensorMetadata getInputTensorMetadata(int inputIndex);
+public QuantizationParams getInputTensorQuantizationParams(int inputIndex);
+public int[] getInputTensorShape(int inputIndex);
+public int getOutputTensorCount();
+public TensorMetadata getOutputTensorMetadata(int outputIndex);
+public QuantizationParams getOutputTensorQuantizationParams(int outputIndex);
+public int[] getOutputTensorShape(int outputIndex);
+```
+
+You can also read associated files through their names with the
+`getAssociatedFile` method:
+
+```java
+public InputStream getAssociatedFile(String fileName);
+```
+
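+For example, a minimal sketch tying these methods together (assuming a
+`ByteBuffer` named `modelBuffer` that holds the model; `labels.txt` is a
+hypothetical associated file name, so use the name recorded in your model's
+metadata):
+
+```java
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import org.tensorflow.lite.support.metadata.MetadataExtractor;
+
+final class MetadataInspector {
+  static void inspect(ByteBuffer modelBuffer) throws IOException {
+    MetadataExtractor extractor = new MetadataExtractor(modelBuffer);
+    if (!extractor.hasMetadata()) {
+      return;  // Methods that read metadata would throw on this model.
+    }
+    if (extractor.isMinimumParserVersionSatisfied()) {
+      // Shape of the first input tensor, e.g. {1, 224, 224, 3}.
+      int[] inputShape = extractor.getInputTensorShape(/*inputIndex=*/ 0);
+      // Read a packed associated file by name.
+      InputStream labels = extractor.getAssociatedFile("labels.txt");
+    }
+  }
+}
+```
+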
+Though the
+[TensorFlow Lite model schema](https://github.com/tensorflow/tensorflow/blob/aa7ff6aa28977826e7acae379e82da22482b2bf2/tensorflow/lite/schema/schema.fbs#L1075)
+supports multiple subgraphs, the TFLite Interpreter currently only supports a
+single subgraph. Therefore, `MetadataExtractor` omits the subgraph index as an
+input argument in its methods.
diff --git a/tensorflow/lite/g3doc/guide/android.md b/tensorflow/lite/g3doc/guide/android.md
index 26885de..420269d 100644
--- a/tensorflow/lite/g3doc/guide/android.md
+++ b/tensorflow/lite/g3doc/guide/android.md
@@ -16,7 +16,7 @@
 The application can run either on device or emulator.
 
 Inference is performed using the TensorFlow Lite Java API and the
-[TensorFlow Lite Android Support Library](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/README.md).
+[TensorFlow Lite Android Support Library](../inference_with_metadata/lite_support.md).
 The demo app classifies frames in real-time, displaying the top most probable
 classifications. It allows the user to choose between a floating point or
 [quantized](https://www.tensorflow.org/lite/performance/post_training_quantization)
@@ -41,6 +41,36 @@
 The following sections contain some useful information for working with
 TensorFlow Lite on Android.
 
+### Use the TensorFlow Lite Task Library
+
+TensorFlow Lite Task Library contains a set of powerful and easy-to-use
+task-specific libraries for app developers to create ML experiences with TFLite.
+It provides optimized out-of-box model interfaces for popular machine learning
+tasks, such as image classification, question and answer, etc. The model
+interfaces are specifically designed for each task to achieve the best
+performance and usability. Task Library works cross-platform and is supported on
+Java, C++, and Swift (coming soon).
+
+To use the Task Library in your Android app, we recommend using the AARs
+hosted at JCenter: the
+[Task Vision library](https://bintray.com/google/tensorflow/tensorflow-lite-task-vision)
+for vision tasks and the
+[Task Text library](https://bintray.com/google/tensorflow/tensorflow-lite-task-text)
+for text tasks.
+
+You can specify this in your `build.gradle` dependencies as follows:
+
+```build
+dependencies {
+    implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
+    implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
+}
+```
+
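+As a minimal sketch of what using the Task Library looks like (class and
+method names as found in the nightly Task Vision library; the model file name
+is a placeholder for a metadata-enabled classification model bundled in your
+app's assets):
+
+```java
+import android.content.Context;
+import android.graphics.Bitmap;
+import java.io.IOException;
+import java.util.List;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.task.vision.classifier.Classifications;
+import org.tensorflow.lite.task.vision.classifier.ImageClassifier;
+
+final class ClassifyDemo {
+  static List<Classifications> classify(Context context, Bitmap bitmap)
+      throws IOException {
+    // "mobilenet_v1.tflite" is a placeholder asset name.
+    ImageClassifier classifier =
+        ImageClassifier.createFromFile(context, "mobilenet_v1.tflite");
+    return classifier.classify(TensorImage.fromBitmap(bitmap));
+  }
+}
+```
+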
+See the introduction in the
+[TensorFlow Lite Task Library overview](../inference_with_metadata/task_library/overview.md)
+for more details.
+
 ### Use the TensorFlow Lite Android Support Library
 
 The TensorFlow Lite Android Support Library makes it easier to integrate models
@@ -52,8 +82,19 @@
 arrays. It also provides pre- and post-processing units that perform tasks such
 as image resizing and cropping.
 
+To use the Support Library in your Android app, we recommend using the
+[TensorFlow Lite Support Library AAR hosted at JCenter](https://bintray.com/google/tensorflow/tensorflow-lite-support).
+
+You can specify this in your `build.gradle` dependencies as follows:
+
+```build
+dependencies {
+    implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
+}
+```
+
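+For instance, the Support Library can wrap an Android `Bitmap` so it is ready
+to feed into a TFLite model (a small sketch; `bitmap` is assumed in scope):
+
+```java
+import org.tensorflow.lite.support.image.TensorImage;
+
+// Wrap a Bitmap as a TensorImage accepted by TFLite model inputs.
+TensorImage tensorImage = TensorImage.fromBitmap(bitmap);
+```
+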
 To get started, follow the instructions in the
-[TensorFlow Lite Android Support Library README.md](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/README.md).
+[TensorFlow Lite Android Support Library](../inference_with_metadata/lite_support.md).
 
 ### Use the TensorFlow Lite AAR from JCenter
 
diff --git a/tensorflow/lite/g3doc/guide/codegen.md b/tensorflow/lite/g3doc/guide/codegen.md
deleted file mode 100644
index 84dd2ff..0000000
--- a/tensorflow/lite/g3doc/guide/codegen.md
+++ /dev/null
@@ -1,243 +0,0 @@
-# Integrate TensorFlow Lite models with metadata
-
-[TensorFlow Lite metadata](../convert/metadata.md) contains a rich description
-of what the model does and how to use the model. It can empower code generators,
-such as the
-[TensorFlow Lite Android code generator](#generate-code-with-tensorflow-lite-android-code-generator)
-and the
-[Android Studio ML Binding feature](#generate-code-with-android-studio-ml-model-binding),
-to automatically generates the inference code for you. It can also be used to
-configure your custom inference pipeline.
-
-Browse
-[TensorFlow Lite hosted models](https://www.tensorflow.org/lite/guide/hosted_models)
-and [TensorFlow Hub](https://tfhub.dev/s?deployment-format=lite) to download
-pretrained models with metadata. All image models have been supported.
-
-## Generate code with TensorFlow Lite Android code generator
-
-Note: TensorFlow Lite wrapper code generator currently only supports Android.
-
-For TensorFlow Lite model enhanced with [metadata](../convert/metadata.md),
-developers can use the TensorFlow Lite Android wrapper code generator to create
-platform specific wrapper code. The wrapper code removes the need to interact
-directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow
-Lite model with typed objects such as `Bitmap` and `Rect`.
-
-The usefulness of the code generator depend on the completeness of the
-TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section
-under relevant fields in
-[metadata_schema.fbs](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/metadata_schema.fbs),
-to see how the codegen tool parses each field.
-
-### Generate Wrapper Code
-
-You will need to install the following tooling in your terminal:
-
-```sh
-pip install tflite-support
-```
-
-Once completed, the code generator can be used using the following syntax:
-
-```sh
-tflite_codegen --model=./model_with_metadata/mobilenet_v1_0.75_160_quantized.tflite \
-    --package_name=org.tensorflow.lite.classify \
-    --model_class_name=MyClassifierModel \
-    --destination=./classify_wrapper
-```
-
-The resulting code will be located in the destination directory. If you are
-using [Google Colab](https://colab.research.google.com/) or other remote
-environment, it maybe easier to zip up the result in a zip archive and download
-it to your Android Studio project:
-
-```python
-## Zip up the generated code
-!zip -r classify_wrapper.zip classify_wrapper/
-
-## Kick off the download
-from google.colab import files
-files.download('classify_wrapper.zip')
-```
-
-### Using the generated code
-
-#### Step 1: Import the generated code
-
-Unzip the generated code if necessary into a directory structure. The root of
-the generated code is assumed to be `SRC_ROOT`.
-
-Open the Android Studio project where you would like to use the TensorFlow lite
-model and import the generated module by: And File -> New -> Import Module ->
-select `SRC_ROOT`
-
-Using the above example, the directory and the module imported would be called
-`classify_wrapper`.
-
-#### Step 2: Update the app's `build.gradle` file
-
-In the app module that will be consuming the generated library module:
-
-Under the android section, add the following:
-
-```build
-aaptOptions {
-   noCompress "tflite"
-}
-```
-
-Under the dependencies section, add the following:
-
-```build
-implementation project(":classify_wrapper")
-```
-
-#### Step 3: Using the model
-
-```java
-// 1. Initialize the model
-MyClassifierModel myImageClassifier = null;
-
-try {
-    myImageClassifier = new MyClassifierModel(this);
-} catch (IOException io){
-    // Error reading the model
-}
-
-if(null != myImageClassifier) {
-
-    // 2. Set the input with a Bitmap called inputBitmap
-    MyClassifierModel.Inputs inputs = myImageClassifier.createInputs();
-    inputs.loadImage(inputBitmap));
-
-    // 3. Run the model
-    MyClassifierModel.Outputs outputs = myImageClassifier.run(inputs);
-
-    // 4. Retrieve the result
-    Map<String, Float> labeledProbability = outputs.getProbability();
-}
-```
-
-### Accelerating model inference
-
-The generated code provides a way for developers to accelerate their code
-through the use of [delegates](../performance/delegates.md) and the number of
-threads. These can be set when initiatizing the model object as it takes three
-parameters:
-
-*   **`Context`**: Context from the Android Activity or Service
-*   (Optional) **`Device`**: TFLite acceleration delegate for example
-    GPUDelegate or NNAPIDelegate
-*   (Optional) **`numThreads`**: Number of threads used to run the model -
-    default is one.
-
-For example, to use a NNAPI delegate and up to three threads, you can initialize
-the model like this:
-
-```java
-try {
-    myImageClassifier = new MyClassifierModel(this, Model.Device.NNAPI, 3);
-} catch (IOException io){
-    // Error reading the model
-}
-```
-
-### Troubleshooting
-
-#### Getting 'java.io.FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed'
-
-Under the app module that will uses the library module, insert the following
-lines under the android section:
-
-```build
-aaptOptions {
-   noCompress "tflite"
-}
-```
-
-## Generate code with Android Studio ML Model Binding
-
-[Android Studio ML Model Binding](https://developer.android.com/studio/preview/features#tensor-flow-lite-models)
-allows you to directly import TensorFlow Lite models and use them in your
-Android Studio projects. It generates easy-to-use classes so you can run your
-model with less code and better type safety. See the
-[introduction](https://developer.android.com/studio/preview/features#tensor-flow-lite-models)
-for more details.
-
-Note: Code generated by the TensorFlow Lite Android code generator may include
-some latest API or experimental features, which can be a super set of the one
-generated by the Android Studio ML Model Binding.
-
-## Read the metadata from models
-
-The Metadata Extractor library is a convenient tool to read the metadata and
-associated files from a models across different platforms (see the
-[Java version](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/metadata)
-and the C++ version is coming soon). Users can also build their own metadata
-extractor tool in other languages using the Flatbuffers library.
-
-### Read the metadata in Java
-
-Note: the Java Metadata Extractor library is available as an Android library
-dependency: `org.tensorflow:tensorflow-lite-metadata`.
-
-You can initialize a `MetadataExtractor` with a `ByteBuffer` that points to the
-model:
-
-```java
-public MetadataExtractor(ByteBuffer buffer);
-```
-
-The `ByteBuffer` must remain unchanged for the whole lifetime of the
-`MetadataExtractor`. The initialization may fail if the Flatbuffers file
-identifier of the model metadata does not match the one of the metadata parser.
-See [metadata versioning](../convert/metadata.md#metadata-versioning) for more
-information.
-
-As long as the file identifer is satisfied, the metadata extractor will not fail
-when reading metadata generated from an old or a future scheme due to the
-Flatbuffers forward and backwards compatibility mechanism. But fields from
-future schemas cannot be extracted by older metadata extractors. The
-[minimum necessary parser version](../convert/metadata.md#the-minimum-necessary-metadata-parser-version)
-of the metadata indicates the minimum version of metadata parser that can read
-the metadata Flatbuffers in full. You can use the following method to verify if
-the minimum necessary parser version is satisfied:
-
-```java
-public final boolean isMinimumParserVersionSatisfied();
-```
-
-It is allowed to pass in a model without metadata. However, invoking methods
-that read from the metadata will cause runtime errors. You can check if a model
-has metadata by invoking the method:
-
-```java
-public boolean hasMetadata();
-```
-
-`MetadataExtractor` provides convenient functions for you to get the
-input/output tensors' metadata. For example,
-
-```java
-public int getInputTensorCount();
-public TensorMetadata getInputTensorMetadata(int inputIndex);
-public QuantizationParams getInputTensorQuantizationParams(int inputIndex);
-public int[] getInputTensorShape(int inputIndex);
-public int getoutputTensorCount();
-public TensorMetadata getoutputTensorMetadata(int inputIndex);
-public QuantizationParams getoutputTensorQuantizationParams(int inputIndex);
-public int[] getoutputTensorShape(int inputIndex);
-```
-
-You can also read associated files through their names with the method:
-
-```java
-public InputStream getAssociatedFile(String fileName);
-```
-
-Though the
-[TensorFlow Lite model schema](https://github.com/tensorflow/tensorflow/blob/aa7ff6aa28977826e7acae379e82da22482b2bf2/tensorflow/lite/schema/schema.fbs#L1075)
-supports multiple subgraphs, the TFLite Interpreter only supports single
-subgraph so far. Therefore, `MetadataExtractor` omits subgraph index as an input
-in its methods.
diff --git a/tensorflow/lite/g3doc/guide/hosted_models.md b/tensorflow/lite/g3doc/guide/hosted_models.md
index a97be10..32887a5 100644
--- a/tensorflow/lite/g3doc/guide/hosted_models.md
+++ b/tensorflow/lite/g3doc/guide/hosted_models.md
@@ -4,8 +4,8 @@
 TensorFlow Lite.
 
 To get started choosing a model, visit <a href="../models">Models</a> page with
-end-to-end examples, or pick a [TensorFlow Lite model from TensorFlow Hub]
-(https://tfhub.dev/s?deployment-format=lite).
+end-to-end examples, or pick a
+[TensorFlow Lite model from TensorFlow Hub](https://tfhub.dev/s?deployment-format=lite).
 
 Note: The best model for a given application depends on your requirements. For
 example, some applications might benefit from higher accuracy, while others
@@ -16,6 +16,9 @@
 
 For more information about image classification, see
 <a href="../models/image_classification/overview.md">Image classification</a>.
+Explore the TensorFlow Lite Task Library for instructions about
+[how to integrate image classification models](../inference_with_metadata/task_library/image_classifier)
+in just a few lines of code.
 
 ### Quantized models
 
@@ -24,7 +27,8 @@
 the expense of accuracy. The performance values are measured on Pixel 3 on
 Android 10.
 
-You can find many [quantized models](https://tfhub.dev/s?deployment-format=lite&module-type=image-classification&q=quantized)
+You can find many
+[quantized models](https://tfhub.dev/s?deployment-format=lite&module-type=image-classification&q=quantized)
 from TensorFlow Hub and get more model information there.
 
 Model name                  | Paper and model                                                                                                                                                                   | Model size | Top-1 accuracy | Top-5 accuracy | CPU, 4 threads | NNAPI
@@ -54,8 +58,8 @@
 Note: The model files include both TF Lite FlatBuffer and Tensorflow frozen
 Graph.
 
-Note: Performance numbers were benchmarked on Pixel-3 (Android 10).
-Accuracy numbers were computed using the
+Note: Performance numbers were benchmarked on Pixel-3 (Android 10). Accuracy
+numbers were computed using the
 [TFLite image classification evaluation tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification).
 
 ### Floating point models
@@ -65,7 +69,8 @@
 of floating point models. The performance values are measured on Pixel 3 on
 Android 10.
 
-You can find many [image classification models](https://tfhub.dev/s?deployment-format=lite&module-type=image-classification)
+You can find many
+[image classification models](https://tfhub.dev/s?deployment-format=lite&module-type=image-classification)
 from TensorFlow Hub and get more model information there.
 
 Model name            | Paper and model                                                                                                                                                                           | Model size | Top-1 accuracy | Top-5 accuracy | CPU, 4 threads | GPU    | NNAPI
@@ -102,8 +107,9 @@
 <a href="https://cloud.google.com/automl/">Cloud AutoML</a>. The performance
 values are measured on Pixel 3 on Android 10.
 
-You can find these models in [TensorFlow Hub](https://tfhub.dev/s?deployment-format=lite&q=MnasNet)
-and get more model information there.
+You can find these models in
+[TensorFlow Hub](https://tfhub.dev/s?deployment-format=lite&q=MnasNet) and get
+more model information there.
 
 Model Name       | Paper and model                                                                                                                                                | Model size | Top-1 accuracy | Top-5 accuracy | CPU, 4 threads | GPU     | NNAPI
 ---------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | -------------: | ------: | ----:
@@ -116,16 +122,20 @@
 MnasNet_1.0_224  | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz)  | 17 Mb      | 74.08%         | 91.75%         | 19.4 ms        | 8.7 ms  | 19 ms
 MnasNet_1.3_224  | [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz)  | 24 Mb      | 75.24%         | 92.55%         | 27.9 ms        | 10.6 ms | 22.0 ms
 
-Note: Performance numbers were benchmarked on Pixel-3 (Android 10).
-Accuracy numbers were computed using the
+Note: Performance numbers were benchmarked on Pixel-3 (Android 10). Accuracy
+numbers were computed using the
 [TFLite image classification evaluation tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification).
 
 ## Object detection
 
 For more information about object detection, see
-<a href="../models/object_detection/overview.md">Object detection</a>.
+<a href="../models/object_detection/overview.md">Object detection</a>. Explore
+the TensorFlow Lite Task Library for instructions about
+[how to integrate object detection models](../inference_with_metadata/task_library/object_detector)
+in just a few lines of code.
 
-Please find [object detection models](https://tfhub.dev/s?deployment-format=lite&module-type=image-object-detection)
+Please find
+[object detection models](https://tfhub.dev/s?deployment-format=lite&module-type=image-object-detection)
 from TensorFlow Hub.
 
 ## Pose estimation
@@ -133,21 +143,29 @@
 For more information about pose estimation, see
 <a href="../models/pose_estimation/overview.md">Pose estimation</a>.
 
-Please find [pose estimation models](https://tfhub.dev/s?deployment-format=lite&module-type=image-pose-detection)
+Please find
+[pose estimation models](https://tfhub.dev/s?deployment-format=lite&module-type=image-pose-detection)
 from TensorFlow Hub.
 
 ## Image segmentation
 
 For more information about image segmentation, see
-<a href="../models/segmentation/overview.md">Segmentation</a>.
+<a href="../models/segmentation/overview.md">Segmentation</a>. Explore the
+TensorFlow Lite Task Library for instructions about
+[how to integrate image segmentation models](../inference_with_metadata/task_library/image_segmenter)
+in just a few lines of code.
 
-Please find [image segmentation models](https://tfhub.dev/s?deployment-format=lite&module-type=image-segmentation)
+Please find
+[image segmentation models](https://tfhub.dev/s?deployment-format=lite&module-type=image-segmentation)
 from TensorFlow Hub.
 
 ## Question and Answer
 
-For more information about text classification with Mobile BERT, see
-<a href="../models/bert_qa/overview.md">Question And Answer</a>.
+For more information about question and answer with MobileBERT, see
+<a href="../models/bert_qa/overview.md">Question And Answer</a>. Explore the
+TensorFlow Lite Task Library for instructions about
+[how to integrate question and answer models](../inference_with_metadata/task_library/bert_question_answerer)
+in just a few lines of code.
 
 Please find [Mobile BERT model](https://tfhub.dev/tensorflow/mobilebert/1) from
 TensorFlow Hub.
diff --git a/tensorflow/lite/g3doc/guide/inference.md b/tensorflow/lite/g3doc/guide/inference.md
index fbf03ab..9b3ebf4 100644
--- a/tensorflow/lite/g3doc/guide/inference.md
+++ b/tensorflow/lite/g3doc/guide/inference.md
@@ -84,7 +84,7 @@
 directly with `ByteBuffer` on Android. Instead, developers can interact with the
 TensorFlow Lite model with typed objects such as `Bitmap` and `Rect`. For more
 information, please refer to the
-[TensorFlow Lite Android wrapper code generator](codegen.md).
+[TensorFlow Lite Android wrapper code generator](../inference_with_metadata/codegen.md).
 
 ### iOS
 
diff --git a/tensorflow/lite/g3doc/guide/ops_select.md b/tensorflow/lite/g3doc/guide/ops_select.md
index 5aa3e96..b9e5b34 100644
--- a/tensorflow/lite/g3doc/guide/ops_select.md
+++ b/tensorflow/lite/g3doc/guide/ops_select.md
@@ -21,6 +21,10 @@
 
 Models converted with TensorFlow ops will require a TensorFlow Lite interpreter
 that has a larger binary size than the interpreter with only TFLite builtin ops.
+For Android, it is possible to reduce the binary size by selectively linking
+only the required TensorFlow ops. For details, see the
+[Reduce TensorFlow Lite binary size](../guide/reduce_binary_size.md) section.
+
 Additionally, performance optimizations will not be available for any TensorFlow
 ops in the TensorFlow Lite model.
 
@@ -66,7 +70,7 @@
 ```
 
 The following example shows how to use this feature in the
-[`tflite_convert`](../convert/cmdline_examples.md) command line tool using the
+[`tflite_convert`](../convert/cmdline.md) command line tool using the
 command line flag `target_ops`.
 
 ```sh
@@ -98,8 +102,10 @@
 
 ### Android AAR
 
-For Android, we recommend using the prebuilt [AAR with TensorFlow ops hosted at
-JCenter](https://bintray.com/google/tensorflow/tensorflow-lite-select-tf-ops).
+To reduce the binary size, build your own custom AAR files as described in
+the [next section](#building-the-android-aar). If binary size is not a
+significant concern, we recommend using the prebuilt
+[AAR with TensorFlow ops hosted at JCenter](https://bintray.com/google/tensorflow/tensorflow-lite-select-tf-ops).
 
 You can specify this in your `build.gradle` dependencies by adding it alongside
 the standard TensorFlow Lite AAR as follows:
@@ -112,9 +118,9 @@
 }
 ```
 
-Once you've added the dependency, the necessary delegate for handling
-the graph's TensorFlow ops should be automatically installed for
-graphs that require them.
+Once you've added the dependency, the necessary delegate for handling the
+graph's TensorFlow ops should be automatically installed for graphs that require
+them.
 
 *Note*: The TensorFlow ops dependency is relatively large, so you'll probably
 want to filter out unnecessary x86 ABIs in your `.gradle` file by setting up
@@ -132,23 +138,32 @@
 
 #### Building the Android AAR
 
-For more advanced cases, you can also build the library manually. Assuming a
-<a href="android.md">working TensorFlow Lite build environment</a>, build the
-Android AAR with select TensorFlow ops as follows:
+To reduce the binary size, or for other advanced cases, you can also build the
+library manually. Assuming a <a href="android.md">working TensorFlow Lite build
+environment</a>, build the Android AAR with select TensorFlow ops as follows:
 
 ```sh
-bazel build --cxxopt='--std=c++14' -c opt   \
-  --config=android_arm --config=monolithic  \
-  //tensorflow/lite/java:tensorflow-lite-select-tf-ops
+sh tensorflow/lite/tools/build_aar.sh \
+  --input_models=/a/b/model_one.tflite,/c/d/model_two.tflite \
+  --target_archs=x86,x86_64,arm64-v8a,armeabi-v7a
 ```
 
-This will generate an AAR file in `bazel-bin/tensorflow/lite/java/`. From there,
-you can either import the AAR directly into your project, or publish the custom
-AAR to your local Maven repository:
+This will generate the AAR file `bazel-bin/tmp/tensorflow-lite.aar` for
+TensorFlow Lite built-in and custom ops, and the AAR file
+`bazel-bin/tmp/tensorflow-lite-select-tf-ops.aar` for TensorFlow ops. If you
+don't have a working build environment, you can also
+[build the above files with Docker](../guide/reduce_binary_size.md#selectively_build_tensorflow_lite_with_docker).
+
+From there, you can either import the AAR files directly into your project, or
+publish the custom AAR files to your local Maven repository:
 
 ```sh
 mvn install:install-file \
-  -Dfile=bazel-bin/tensorflow/lite/java/tensorflow-lite-select-tf-ops.aar \
+  -Dfile=bazel-bin/tmp/tensorflow-lite.aar \
+  -DgroupId=org.tensorflow \
+  -DartifactId=tensorflow-lite -Dversion=0.1.100 -Dpackaging=aar
+mvn install:install-file \
+  -Dfile=bazel-bin/tmp/tensorflow-lite-select-tf-ops.aar \
   -DgroupId=org.tensorflow \
   -DartifactId=tensorflow-lite-select-tf-ops -Dversion=0.1.100 -Dpackaging=aar
 ```
@@ -166,7 +181,8 @@
 }
 
 dependencies {
-    implementation 'org.tensorflow:tensorflow-lite-with-select-tf-ops:0.1.100'
+    implementation 'org.tensorflow:tensorflow-lite:0.1.100'
+    implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.1.100'
 }
 ```
 
@@ -285,10 +301,16 @@
 The following table describes the binary size of TensorFlow Lite for each build.
 These targets were built for Android using `--config=android_arm -c opt`.
 
-Build                 | C++ Binary Size | Android APK Size
---------------------- | --------------- | ----------------
-Only built-in ops     | 796 KB          | 561 KB
-Built-in ops + TF ops | 23.0 MB         | 8.0 MB
+Build                     | C++ Binary Size | Android APK Size
+------------------------- | --------------- | ----------------
+Only built-in ops         | 796 KB          | 561 KB
+Built-in ops + TF ops     | 23.0 MB         | 8.0 MB
+Built-in ops + TF ops (1) | 4.1 MB          | 1.8 MB
+
+(1) These libraries are selectively built for the
+[i3d-kinetics-400 model](https://tfhub.dev/deepmind/i3d-kinetics-400/1) with 8
+TFLite builtin ops and 3 TensorFlow ops. For more details, see the
+[Reduce TensorFlow Lite binary size](../guide/reduce_binary_size.md) section.
 
 ## Known limitations
 
@@ -309,10 +331,6 @@
 
 The following is a list of improvements to this pipeline that are in progress:
 
-*   *Selective registration* - There is work being done to make it simple to
-    generate TFLite interpreter binaries that only contain the TensorFlow ops
-    required for a particular set of models.
-*   *Improved usability* - The conversion process will be simplified to only
-    require a single pass through the converter.
 *   *Improved performance* - Work is being done to ensure TensorFlow Lite with
-    TensorFlow ops has performance parity to TensorFlow Mobile.
+    TensorFlow ops works well with hardware-accelerated delegates, for
+    example, the NNAPI and GPU delegates.
diff --git a/tensorflow/lite/g3doc/guide/roadmap.md b/tensorflow/lite/g3doc/guide/roadmap.md
index b762db1..7adb2d1 100644
--- a/tensorflow/lite/g3doc/guide/roadmap.md
+++ b/tensorflow/lite/g3doc/guide/roadmap.md
@@ -37,6 +37,13 @@
 *   **More models and examples**
     *   More examples to demonstrate model usage as well as new features and
         APIs, covering different platforms.
+*   **Task Library**
+    *   Improve the usability of the C++ Task Library, such as providing
+        prebuilt binaries and creating user-friendly workflows for users who
+        want to build from source code.
+    *   Release reference examples of using the Task Library.
+    *   Enable more task types.
+    *   Improve cross-platform support and enable more tasks for iOS.
 
 ## Performance
 
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/codegen.md b/tensorflow/lite/g3doc/inference_with_metadata/codegen.md
new file mode 100644
index 0000000..b447573
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/codegen.md
@@ -0,0 +1,153 @@
+# Generate model interfaces with TensorFlow Lite code generator
+
+Note: TensorFlow Lite wrapper code generator currently only supports Android.
+
+For TensorFlow Lite models enhanced with [metadata](../convert/metadata.md),
+developers can use the TensorFlow Lite Android wrapper code generator to create
+platform specific wrapper code. The wrapper code removes the need to interact
+directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow
+Lite model with typed objects such as `Bitmap` and `Rect`.
+
+The usefulness of the code generator depends on the completeness of the
+TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section
+under the relevant fields in
+[metadata_schema.fbs](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/metadata_schema.fbs)
+to see how the codegen tool parses each field.
+
+## Generate wrapper code
+
+You will need to install the following tooling in your terminal:
+
+```sh
+pip install tflite-support
+```
+
+Once installed, the code generator can be invoked with the following syntax:
+
+```sh
+tflite_codegen --model=./model_with_metadata/mobilenet_v1_0.75_160_quantized.tflite \
+    --package_name=org.tensorflow.lite.classify \
+    --model_class_name=MyClassifierModel \
+    --destination=./classify_wrapper
+```
+
+The resulting code will be located in the destination directory. If you are
+using [Google Colab](https://colab.research.google.com/) or another remote
+environment, it may be easier to zip up the result into an archive and download
+it to your Android Studio project:
+
+```python
+# Zip up the generated code
+!zip -r classify_wrapper.zip classify_wrapper/
+
+# Download the archive
+from google.colab import files
+files.download('classify_wrapper.zip')
+```
+
+## Using the generated code
+
+### Step 1: Import the generated code
+
+If necessary, unzip the generated code into a directory structure. The root of
+the generated code is assumed to be `SRC_ROOT`.
+
+Open the Android Studio project where you would like to use the TensorFlow Lite
+model and import the generated module via File -> New -> Import Module ->
+select `SRC_ROOT`.
+
+Using the above example, the directory and the module imported would be called
+`classify_wrapper`.
+
+### Step 2: Update the app's `build.gradle` file
+
+In the app module that will be consuming the generated library module:
+
+Under the android section, add the following:
+
+```build
+aaptOptions {
+   noCompress "tflite"
+}
+```
+
+Under the dependencies section, add the following:
+
+```build
+implementation project(":classify_wrapper")
+```
+
+### Step 3: Using the model
+
+```java
+// 1. Initialize the model
+MyClassifierModel myImageClassifier = null;
+
+try {
+    myImageClassifier = new MyClassifierModel(this);
+} catch (IOException io){
+    // Error reading the model
+}
+
+if (myImageClassifier != null) {
+
+    // 2. Set the input with a Bitmap called inputBitmap
+    MyClassifierModel.Inputs inputs = myImageClassifier.createInputs();
+    inputs.loadImage(inputBitmap);
+
+    // 3. Run the model
+    MyClassifierModel.Outputs outputs = myImageClassifier.run(inputs);
+
+    // 4. Retrieve the result
+    Map<String, Float> labeledProbability = outputs.getProbability();
+}
+```
+
+## Accelerating model inference
+
+The generated code provides a way for developers to accelerate inference
+through the use of [delegates](../performance/delegates.md) and multiple
+threads. These can be set when initializing the model object, which takes three
+parameters:
+
+*   **`Context`**: Context from the Android Activity or Service
+*   (Optional) **`Device`**: TFLite acceleration delegate, for example
+    `GPUDelegate` or `NNAPIDelegate`
+*   (Optional) **`numThreads`**: Number of threads used to run the model;
+    the default is one.
+
+For example, to use an NNAPI delegate and up to three threads, you can initialize
+the model like this:
+
+```java
+try {
+    myImageClassifier = new MyClassifierModel(this, Model.Device.NNAPI, 3);
+} catch (IOException io){
+    // Error reading the model
+}
+```
+
+## Troubleshooting
+
+If you get a 'java.io.FileNotFoundException: This file can not be opened as a
+file descriptor; it is probably compressed' error, insert the following lines
+under the android section of the app module that uses the library module:
+
+```build
+aaptOptions {
+   noCompress "tflite"
+}
+```
+
+## Generate code with Android Studio ML Model Binding
+
+[Android Studio ML Model Binding](https://developer.android.com/studio/preview/features#tensor-flow-lite-models)
+allows you to directly import TensorFlow Lite models and use them in your
+Android Studio projects. It generates easy-to-use classes so you can run your
+model with less code and better type safety. See the
+[introduction](https://developer.android.com/studio/preview/features#tensor-flow-lite-models)
+for more details.
+
+Note: Code generated by the TensorFlow Lite Android code generator may include
+the latest APIs or experimental features, which can be a superset of the code
+generated by the Android Studio ML Model Binding.
diff --git a/tensorflow/lite/g3doc/guide/lite_support.md b/tensorflow/lite/g3doc/inference_with_metadata/lite_support.md
similarity index 95%
rename from tensorflow/lite/g3doc/guide/lite_support.md
rename to tensorflow/lite/g3doc/inference_with_metadata/lite_support.md
index 39eeeee..ce7b9c0 100644
--- a/tensorflow/lite/g3doc/guide/lite_support.md
+++ b/tensorflow/lite/g3doc/inference_with_metadata/lite_support.md
@@ -40,6 +40,10 @@
 }
 ```
 
+Explore the
+[TensorFlow Lite Support Library AAR hosted at JCenter](https://bintray.com/google/tensorflow/tensorflow-lite-support)
+for different versions of the Support Library.
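+
+For example, to pin a specific release instead of the nightly snapshot, declare
+it in the dependencies of your `build.gradle` file (the version below is
+illustrative; check the AAR listing above for available versions):
+
+```build
+dependencies {
+    implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
+}
+```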
+
 ### Basic image manipulation and conversion
 
 The TensorFlow Lite Support Library has a suite of basic image manipulation
@@ -72,7 +76,7 @@
 ```
 
 `DataType` of a tensor can be read through the
-[metadata exractor library](../guide/codegen.md#read-the-metadata-from-models)
+[metadata extractor library](../convert/metadata.md#read-the-metadata-from-models)
 as well as other model information.
 
 ### Create output objects and run the model
@@ -235,4 +239,4 @@
 ```
 
 The quantization parameters of a tensor can be read through the
-[metadata exractor library](../guide/codegen.md#read-the-metadata-from-models).
+[metadata extractor library](../convert/metadata.md#read-the-metadata-from-models).
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/overview.md b/tensorflow/lite/g3doc/inference_with_metadata/overview.md
new file mode 100644
index 0000000..8caa92a
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/overview.md
@@ -0,0 +1,51 @@
+# TensorFlow Lite inference with metadata
+
+Inferencing [models with metadata](../convert/metadata.md) can be as easy as
+just a few lines of code. TensorFlow Lite metadata contains a rich description
+of what the model does and how to use the model. It can empower code generators
+to automatically generate the inference code for you, such as using the
+[TensorFlow Lite Android code generator](codegen.md#generate-code-with-tensorflow-lite-android-code-generator)
+and the
+[Android Studio ML Binding feature](codegen.md#generate-code-with-android-studio-ml-model-binding).
+It can also be used to configure your custom inference pipeline.
+
+## Tools and libraries
+
+TensorFlow Lite provides a variety of tools and libraries to serve different
+tiers of deployment requirements, as follows:
+
+### Generate model interface with the TensorFlow Lite Code Generator
+
+[TensorFlow Lite Code Generator](codegen.md) is an executable that generates a
+model interface automatically based on the metadata. It currently supports
+Android with Java. The wrapper code removes the need to interact directly with
+`ByteBuffer`. Instead, developers can interact with the TensorFlow Lite model
+with typed objects such as `Bitmap` and `Rect`. Android Studio users can also
+get access to the codegen feature through
+[Android Studio ML Binding](codegen.md#generate-code-with-android-studio-ml-model-binding).
+
+### Leverage out-of-box APIs with the TensorFlow Lite Task Library
+
+[TensorFlow Lite Task Library](task_library/overview.md) provides optimized
+ready-to-use model interfaces for popular machine learning tasks, such as image
+classification, question and answer, etc. The model interfaces are specifically
+designed for each task to achieve the best performance and usability. Task
+Library works cross-platform and is supported on Java, C++, and Swift.
+
+### Build custom inference pipelines with the TensorFlow Lite Support Library
+
+[TensorFlow Lite Support Library](lite_support.md) is a cross-platform library
+that helps customize the model interface and build inference pipelines. It
+contains a variety of utility methods and data structures to perform pre/post
+processing and data conversion. It is also designed to match the behavior of
+TensorFlow modules, such as TF.Image and TF.Text, ensuring consistency from
+training to inferencing.
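+
+For instance, here is a hedged sketch of image preprocessing with the Support
+Library (class names per the Support Library documentation linked above; the
+224x224 input size and the existing `bitmap` variable are just illustrative):
+
+```java
+import org.tensorflow.lite.support.image.ImageProcessor;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.support.image.ops.ResizeOp;
+
+// Resize an input Bitmap to the size the model expects
+ImageProcessor imageProcessor =
+    new ImageProcessor.Builder()
+        .add(new ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
+        .build();
+TensorImage tensorImage = TensorImage.fromBitmap(bitmap);
+tensorImage = imageProcessor.process(tensorImage);
+```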
+
+## Explore pretrained models with metadata
+
+Browse
+[TensorFlow Lite hosted models](https://www.tensorflow.org/lite/guide/hosted_models)
+and [TensorFlow Hub](https://tfhub.dev/s?deployment-format=lite) to download
+pretrained models with metadata for both vision and text tasks. Also see
+different options of
+[visualizing the metadata](../convert/metadata.md#visualize-the-metadata).
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md
new file mode 100644
index 0000000..41fc958
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md
@@ -0,0 +1,122 @@
+# Integrate BERT natural language classifier
+
+The Task Library `BertNLClassifier` API is very similar to the `NLClassifier`
+that classifies input text into different categories, except that this API is
+specially tailored for BERT-related models that require Wordpiece and
+Sentencepiece tokenizations outside the TFLite model.
+
+## Key features of the BertNLClassifier API
+
+*   Takes a single string as input, performs classification with the string and
+    outputs <Label, Score> pairs as classification results.
+
+*   Performs out-of-graph
+    [Wordpiece](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h)
+    or
+    [Sentencepiece](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h)
+    tokenizations on input text.
+
+## Supported BertNLClassifier models
+
+The following models are compatible with the `BertNLClassifier` API.
+
+*   BERT models created by
+    [TensorFlow Lite Model Maker for Text Classification](https://www.tensorflow.org/lite/tutorials/model_maker_text_classification).
+
+*   Custom models that meet the
+    [model compatibility requirements](#model-compatibility-requirements).
+
+## Run inference in Java
+
+### Step 1: Import Gradle dependency and other settings
+
+Copy the `.tflite` model file to the assets directory of the Android module
+where the model will be run. Specify that the file should not be compressed, and
+add the TensorFlow Lite library to the module’s `build.gradle` file:
+
+```java
+android {
+    // Other settings
+
+    // Specify tflite file should not be compressed for the app apk
+    aaptOptions {
+        noCompress "tflite"
+    }
+
+}
+
+dependencies {
+    // Other dependencies
+
+    // Import the Task Text Library dependency
+    implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
+}
+```
+
+### Step 2: Run inference using the API
+
+```java
+// Initialization
+BertNLClassifier classifier = BertNLClassifier.createFromFile(context, modelFile);
+
+// Run inference
+List<Category> results = classifier.classify(input);
+```
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java)
+for more details.
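+
+Each returned `Category` pairs a label with a score, as in the example results
+below. A minimal sketch of consuming the list (assuming the `results` list from
+the snippet above and the `Category` getters from the linked source):
+
+```java
+// Print every <label, score> pair
+for (Category category : results) {
+    System.out.println(category.getLabel() + ": " + category.getScore());
+}
+```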
+
+## Run inference in C++
+
+Note: We are working on improving the usability of the C++ Task Library, such as
+providing prebuilt binaries and creating user-friendly workflows to build from
+source code. The C++ API may be subject to change.
+
+```c++
+// Initialization
+std::unique_ptr<BertNLClassifier> classifier = BertNLClassifier::CreateFromFile(model_path).value();
+
+// Run inference
+std::vector<core::Category> categories = classifier->Classify(kInput);
+```
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h)
+for more details.
+
+## Example results
+
+Here is an example of the classification results of movie reviews using the
+[MobileBert](https://www.tensorflow.org/lite/tutorials/model_maker_text_classification)
+model from Model Maker.
+
+Input: "it's a charming and often affecting journey"
+
+Output:
+
+```
+category[0]: 'negative' : '0.00006'
+category[1]: 'positive' : '0.99994'
+```
+
+Try out the simple
+[CLI demo tool for BertNLClassifier](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/text/desktop/README.md#bertnlclassifier)
+with your own model and test data.
+
+## Model compatibility requirements
+
+The `BertNLClassifier` API expects a TFLite model with mandatory
+[TFLite Model Metadata](../../convert/metadata.md).
+
+The Metadata should meet the following requirements:
+
+*   input_process_units for Wordpiece/Sentencepiece Tokenizer
+
+*   3 input tensors with names "ids", "mask" and "segment_ids" for the output of
+    the tokenizer
+
+*   1 output tensor of type float32, with an optionally attached label file. If
+    a label file is attached, it should be a plain text file with one label per
+    line, and the number of labels should match the number of categories that
+    the model outputs.
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md
new file mode 100644
index 0000000..9c41f23
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md
@@ -0,0 +1,136 @@
+# Integrate BERT question answerer
+
+The Task Library `BertQuestionAnswerer` API loads a BERT model and answers
+questions based on the content of a given passage. For more information, see the
+documentation for the Question-Answer model
+<a href="../../models/bert_qa/overview.md">here</a>.
+
+## Key features of the BertQuestionAnswerer API
+
+*   Takes two text inputs as question and context and outputs a list of possible
+    answers.
+
+*   Performs out-of-graph Wordpiece or Sentencepiece tokenizations on input
+    text.
+
+## Supported BertQuestionAnswerer models
+
+The following models are compatible with the `BertQuestionAnswerer` API.
+
+*   Models created by
+    [TensorFlow Lite Model Maker for Question Answer](https://www.tensorflow.org/lite/tutorials/model_maker_question_answer).
+
+*   The
+    [pretrained BERT models on TensorFlow Hub](https://tfhub.dev/tensorflow/collections/lite/task-library/bert-question-answerer/1).
+
+*   Custom models that meet the
+    [model compatibility requirements](#model-compatibility-requirements).
+
+## Run inference in Java
+
+### Step 1: Import Gradle dependency and other settings
+
+Copy the `.tflite` model file to the assets directory of the Android module
+where the model will be run. Specify that the file should not be compressed, and
+add the TensorFlow Lite library to the module’s `build.gradle` file:
+
+```java
+android {
+    // Other settings
+
+    // Specify tflite file should not be compressed for the app apk
+    aaptOptions {
+        noCompress "tflite"
+    }
+
+}
+
+dependencies {
+    // Other dependencies
+
+    // Import the Task Text Library dependency
+    implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
+}
+```
+
+### Step 2: Run inference using the API
+
+```java
+// Initialization
+BertQuestionAnswerer answerer = BertQuestionAnswerer.createFromFile(androidContext, modelFile);
+
+// Run inference
+List<QaAnswer> answers = answerer.answer(contextOfTheQuestion, questionToAsk);
+```
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java)
+for more details.
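+
+Each `QaAnswer` carries the answer text together with its position and logit in
+the context, as shown in the example results below. A minimal sketch of
+consuming the list (assuming `answers` from the snippet above, and that
+`QaAnswer` exposes `text` and `pos` as in the linked source):
+
+```java
+// Print each candidate answer with its logit
+for (QaAnswer answer : answers) {
+    System.out.println(answer.text + " (logit: " + answer.pos.logit + ")");
+}
+```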
+
+## Run inference in C++
+
+Note: We are working on improving the usability of the C++ Task Library, such as
+providing prebuilt binaries and creating user-friendly workflows to build from
+source code. The C++ API may be subject to change.
+
+```c++
+// Initialization
+std::unique_ptr<BertQuestionAnswerer> answerer = BertQuestionAnswerer::CreateFromFile(model_file).value();
+
+// Run inference
+std::vector<QaAnswer> positive_results = answerer->Answer(context_of_question, question_to_ask);
+```
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h)
+for more details.
+
+## Example results
+
+Here is an example of the answer results of the
+[ALBERT model](https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1).
+
+Context: "The Amazon rainforest, alternatively, the Amazon Jungle, also known in
+English as Amazonia, is a moist broadleaf tropical rainforest in the Amazon
+biome that covers most of the Amazon basin of South America. This basin
+encompasses 7,000,000 km2 (2,700,000 sq mi), of which
+5,500,000 km2 (2,100,000 sq mi) are covered by the rainforest. This region
+includes territory belonging to nine nations."
+
+Question: "Where is Amazon rainforest?"
+
+Answers:
+
+```
+answer[0]:  'South America.'
+logit: 1.84847, start_index: 39, end_index: 40
+answer[1]:  'most of the Amazon basin of South America.'
+logit: 1.2921, start_index: 34, end_index: 40
+answer[2]:  'the Amazon basin of South America.'
+logit: -0.0959535, start_index: 36, end_index: 40
+answer[3]:  'the Amazon biome that covers most of the Amazon basin of South America.'
+logit: -0.498558, start_index: 28, end_index: 40
+answer[4]:  'Amazon basin of South America.'
+logit: -0.774266, start_index: 37, end_index: 40
+```
+
+Try out the simple
+[CLI demo tool for BertQuestionAnswerer](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/text/desktop/README.md#bert-question-answerer)
+with your own model and test data.
+
+## Model compatibility requirements
+
+The `BertQuestionAnswerer` API expects a TFLite model with mandatory
+[TFLite Model Metadata](../../convert/metadata.md).
+
+The Metadata should meet the following requirements:
+
+*   `input_process_units` for Wordpiece/Sentencepiece Tokenizer
+
+*   3 input tensors with names "ids", "mask" and "segment_ids" for the output of
+    the tokenizer
+
+*   2 output tensors with names "end_logits" and "start_logits" to indicate the
+    answer's relative position in the context
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/customized_task_api.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/customized_task_api.md
new file mode 100644
index 0000000..68e701d
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/customized_task_api.md
@@ -0,0 +1,448 @@
+# Build your own Task API
+
+<a href="overview.md">TensorFlow Lite Task Library</a> provides prebuilt
+native/Android/iOS APIs on top of the same infrastructure that abstracts
+TensorFlow. You can extend the Task API infrastructure to build customized APIs
+if your model is not supported by existing Task libraries.
+
+## Overview
+
+Task API infrastructure has a two-layer structure: the bottom C++ layer
+encapsulating the native TFLite runtime and the top Java/ObjC layer that
+communicates with the C++ layer through JNI or a native wrapper.
+
+Implementing all the TensorFlow logic only in C++ minimizes cost, maximizes
+inference performance and simplifies the overall workflow across platforms.
+
+To create a Task class, extend the
+[BaseTaskApi](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/core/base_task_api.h)
+to provide conversion logic between TFLite model interface and Task API
+interface, then use the Java/ObjC utilities to create corresponding APIs. With
+all TensorFlow details hidden, you can deploy the TFLite model in your apps
+without any machine learning knowledge.
+
+TensorFlow Lite provides some prebuilt APIs for most popular
+<a href="overview.md#supported_tasks">Vision and NLP tasks</a>. You can build
+your own APIs for other tasks using the Task API infrastructure.
+
+<div align="center">![prebuilt_task_apis](images/prebuilt_task_apis.svg)
+<div align="center">Figure 1. prebuilt Task APIs
+<div align="left">
+
+## Build your own API with Task API infra
+
+### C++ API
+
+All TFLite details are implemented in the native API. Create an API object by
+using one of the factory functions and get model results by calling functions
+defined in the interface.
+
+#### Sample usage
+
+Here is an example using the C++
+[`BertQuestionAnswerer`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h)
+for
+[MobileBert](https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1).
+
+```cpp
+  char kBertModelPath[] = "path/to/model.tflite";
+  // Create the API from a model file
+  std::unique_ptr<BertQuestionAnswerer> question_answerer =
+      BertQuestionAnswerer::CreateFromFile(kBertModelPath).value();
+
+  char kContext[] = ...; // context of a question to be answered
+  char kQuestion[] = ...; // question to be answered
+  // ask a question
+  std::vector<QaAnswer> answers = question_answerer->Answer(kContext, kQuestion);
+  // answers[0].text is the best answer
+```
+
+#### Building the API
+
+<div align="center">![native_task_api](images/native_task_api.svg)
+<div align="center">Figure 2. Native Task API
+<div align="left">
+
+To build an API object, you must provide the following information by extending
+[`BaseTaskApi`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/core/base_task_api.h)
+
+*   __Determine the API I/O__ - Your API should expose similar input/output
+    across different platforms. e.g. `BertQuestionAnswerer` takes two strings
+    `(std::string& context, std::string& question)` as input and outputs a
+    vector of possible answers and probabilities as `std::vector<QaAnswer>`. This
+    is done by specifying the corresponding types in `BaseTaskApi`'s
+    [template parameter](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/core/base_task_api.h?q="template <class OutputType, class... InputTypes>").
+    With the template parameters specified, the
+    [`BaseTaskApi::Infer`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/core/base_task_api.h?q="Infer\(InputTypes... args\)")
+    function will have the correct input/output types. This function can be
+    directly called by API clients, but it is a good practice to wrap it inside
+    a model-specific function, in this case, `BertQuestionAnswerer::Answer`.
+
+    ```cpp
+    class BertQuestionAnswerer : public BaseTaskApi<
+                                  std::vector<QaAnswer>, // OutputType
+                                  const std::string&, const std::string& // InputTypes
+                                  > {
+      // Model specific function delegating calls to BaseTaskApi::Infer
+      std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) {
+        return Infer(context, question).value();
+      }
+    };
+    ```
+
+*   __Provide conversion logic between API I/O and input/output tensor of the
+    model__ - With input and output types specified, the subclasses also need to
+    implement the typed functions
+    [`BaseTaskApi::Preprocess`](https://github.com/tensorflow/tflite-support/blob/5cea306040c40b06d6e0ed4e5baf6c307db7bd00/tensorflow_lite_support/cc/task/core/base_task_api.h#L74)
+    and
+    [`BaseTaskApi::Postprocess`](https://github.com/tensorflow/tflite-support/blob/5cea306040c40b06d6e0ed4e5baf6c307db7bd00/tensorflow_lite_support/cc/task/core/base_task_api.h#L80).
+    The two functions provide
+    [inputs](https://github.com/tensorflow/tensorflow/blob/1b84e5af78f85b8d3c4687b7dee65b78113f81cc/tensorflow/lite/schema/schema.fbs#L1007)
+    and
+    [outputs](https://github.com/tensorflow/tensorflow/blob/1b84e5af78f85b8d3c4687b7dee65b78113f81cc/tensorflow/lite/schema/schema.fbs#L1008)
+    from the TFLite `FlatBuffer`. The subclass is responsible for assigning
+    values from the API I/O to I/O tensors. See the complete implementation
+    example in
+    [`BertQuestionAnswerer`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc).
+
+    ```cpp
+    class BertQuestionAnswerer : public BaseTaskApi<
+                                  std::vector<QaAnswer>, // OutputType
+                                  const std::string&, const std::string& // InputTypes
+                                  > {
+      // Convert API input into tensors
+      absl::Status Preprocess(
+        const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model
+        const std::string& context, const std::string& query // InputType of the API
+      ) {
+        // Perform tokenization on input strings
+        ...
+        // Populate IDs, Masks and SegmentIDs to corresponding input tensors
+        PopulateTensor(input_ids, input_tensors[0]);
+        PopulateTensor(input_mask, input_tensors[1]);
+        PopulateTensor(segment_ids, input_tensors[2]);
+        return absl::OkStatus();
+      }
+
+      // Convert output tensors into API output
+      StatusOr<std::vector<QaAnswer>> // OutputType
+      Postprocess(
+        const std::vector<const TfLiteTensor*>& output_tensors // output tensors of the model
+      ) {
+        // Get start/end logits of prediction result from output tensors
+        std::vector<float> end_logits;
+        std::vector<float> start_logits;
+        // output_tensors[0]: end_logits FLOAT[1, 384]
+        PopulateVector(output_tensors[0], &end_logits);
+        // output_tensors[1]: start_logits FLOAT[1, 384]
+        PopulateVector(output_tensors[1], &start_logits);
+        ...
+        std::vector<QaAnswer::Pos> orig_results;
+        // Look up the indices from vocabulary file and build results
+        ...
+        return orig_results;
+      }
+    };
+    ```
+
+*   __Create factory functions of the API__ - A model file and a
+    [`OpResolver`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/core/api/op_resolver.h)
+    are needed to initialize the
+    [`tflite::Interpreter`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/interpreter.h).
+    [`TaskAPIFactory`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/core/task_api_factory.h)
+    provides utility functions to create BaseTaskApi instances.
+
+    Note: By default
+    [`TaskAPIFactory`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/core/task_api_factory.h)
+    provides a
+    [`BuiltInOpResolver`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/register.h).
+    If your model needs customized ops or a subset of built-in ops, you can
+    register them by creating a
+    [`MutableOpResolver`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/mutable_op_resolver.h).
+
+    You must also provide any files associated with the model. For example,
+    `BertQuestionAnswerer` can also have an additional file for its tokenizer's
+    vocabulary.
+
+    ```cpp
+    class BertQuestionAnswerer : public BaseTaskApi<
+                                  std::vector<QaAnswer>, // OutputType
+                                  const std::string&, const std::string& // InputTypes
+                                  > {
+      // Factory function to create the API instance
+      static StatusOr<std::unique_ptr<QuestionAnswerer>>
+      CreateBertQuestionAnswerer(
+          const std::string& path_to_model, // model to be passed to TaskApiFactory
+          const std::string& path_to_vocab  // additional model specific files
+      ) {
+        // Creates an API object by calling one of the utils from TaskAPIFactory
+        std::unique_ptr<BertQuestionAnswerer> api_to_init;
+        ASSIGN_OR_RETURN(
+            api_to_init,
+            core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
+                path_to_model,
+                absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+                kNumLiteThreads));
+
+        // Perform additional model specific initializations
+        // In this case building a vocabulary vector from the vocab file.
+        api_to_init->InitializeVocab(path_to_vocab);
+        return api_to_init;
+      }
+    };
+    ```
+
+### Android API
+
+Create Android APIs by defining a Java/Kotlin interface and delegating the
+logic to the C++ layer through JNI. The Android API requires the native API to
+be built first.
+
+#### Sample usage
+
+Here is an example using Java
+[`BertQuestionAnswerer`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java)
+for
+[MobileBert](https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1).
+
+```java
+  String BERT_MODEL_FILE = "path/to/model.tflite";
+  String VOCAB_FILE = "path/to/vocab.txt";
+  // Create the API from a model file and vocabulary file
+  BertQuestionAnswerer bertQuestionAnswerer =
+      BertQuestionAnswerer.createBertQuestionAnswerer(
+          ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);
+
+  String CONTEXT = ...; // context of a question to be answered
+  String QUESTION = ...; // question to be answered
+  // ask a question
+  List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
+  // answers.get(0).text is the best answer
+```
+
+#### Building the API
+
+<div align="center">![android_task_api](images/android_task_api.svg)
+<div align="center">Figure 3. Android Task API
+<div align="left">
+
+Similar to Native APIs, to build an API object, the client needs to provide the
+following information by extending
+[`BaseTaskApi`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java),
+which provides JNI handlings for all Java Task APIs.
+
+*   __Determine the API I/O__ - This usually mirrors the native interfaces. e.g.
+    `BertQuestionAnswerer` takes `(String context, String question)` as input
+    and outputs `List<QaAnswer>`. The implementation calls a private native
+    function with similar signature, except it has an additional parameter `long
+    nativeHandle`, which is the pointer returned from C++.
+
+    ```java
+    class BertQuestionAnswerer extends BaseTaskApi {
+      public List<QaAnswer> answer(String context, String question) {
+        return answerNative(getNativeHandle(), context, question);
+      }
+
+      private static native List<QaAnswer> answerNative(
+                                            long nativeHandle, // C++ pointer
+                                            String context, String question // API I/O
+                                           );
+
+    }
+    ```
+
+*   __Create factory functions of the API__ - This also mirrors native factory
+    functions, except Android factory functions also need to take
+    [`Context`](https://developer.android.com/reference/android/content/Context)
+    for file access. The implementation calls one of the utilities in
+    [`TaskJniUtils`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java)
+    to build the corresponding C++ API object and pass its pointer to the
+    `BaseTaskApi` constructor.
+
+    ```java
+      class BertQuestionAnswerer extends BaseTaskApi {
+        private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME =
+                                                  "bert_question_answerer_jni";
+
+        // Extending super constructor by providing the
+        // native handle(pointer of corresponding C++ API object)
+        private BertQuestionAnswerer(long nativeHandle) {
+          super(nativeHandle);
+        }
+
+        public static BertQuestionAnswerer createBertQuestionAnswerer(
+                                            Context context, // Accessing Android files
+                                            String pathToModel, String pathToVocab) {
+          return new BertQuestionAnswerer(
+              // The util first tries to load the JNI module named
+              // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens the two files
+              // and converts them into ByteBuffers; finally
+              // ::initJniWithBertByteBuffers is called with the buffers to get
+              // a C++ API object pointer
+              TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+                  context,
+                  BertQuestionAnswerer::initJniWithBertByteBuffers,
+                  BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
+                  pathToModel,
+                  pathToVocab));
+        }
+
+        // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
+        // returns the C++ API object pointer cast to long
+        private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
+
+      }
+    ```
+
+*   __Implement the JNI module for native functions__ - All Java native methods
+    are implemented by calling a corresponding native function from the JNI
+    module. The factory functions create a native API object and return its
+    pointer to Java as a long. In later calls to the Java API, the long pointer
+    is passed back to JNI and cast back to the native API object. The native
+    API results are then converted back to Java results.
+
+    For example, this is how
+    [bert_question_answerer_jni](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc)
+    is implemented.
+
+    ```cpp
+      // Implements BertQuestionAnswerer::initJniWithBertByteBuffers
+      extern "C" JNIEXPORT jlong JNICALL
+      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
+          JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
+        // Convert Java ByteBuffer object into a buffer that can be read by native factory functions
+        absl::string_view model =
+            GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
+
+        // Creates the native API object
+        absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
+            BertQuestionAnswerer::CreateFromBuffer(
+                model.data(), model.size());
+        if (status.ok()) {
+          // converts the object pointer to jlong and returns it to Java.
+          return reinterpret_cast<jlong>(status->release());
+        } else {
+          return kInvalidPointer;
+        }
+      }
+
+      // Implements BertQuestionAnswerer::answerNative
+      extern "C" JNIEXPORT jobject JNICALL
+      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
+      JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) {
+      // Convert long to native API object pointer
+      QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
+
+      // Calls the native API
+      std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context),
+                                             JStringToString(env, question));
+
+      // Converts the native result (std::vector<QaAnswer>) to a Java result (List<QaAnswer>)
+      jclass qa_answer_class =
+        env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
+      jmethodID qa_answer_ctor =
+        env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
+      return ConvertVectorToArrayList<QaAnswer>(
+        env, results,
+        [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
+          jstring text = env->NewStringUTF(ans.text.data());
+          jobject qa_answer =
+              env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
+                             ans.pos.end, ans.pos.logit);
+          env->DeleteLocalRef(text);
+          return qa_answer;
+        });
+      }
+
+      // Implements BaseTaskApi::deinitJni by deleting the native object
+      extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni(
+          JNIEnv* env, jobject thiz, jlong native_handle) {
+        delete reinterpret_cast<QuestionAnswerer*>(native_handle);
+      }
+    ```
+
+### iOS API
+
+Create iOS APIs by wrapping a native API object into an ObjC API object. The
+created API object can be used in either ObjC or Swift. The iOS API requires
+the native API to be built first.
+
+#### Sample usage
+
+Here is an example using ObjC
+[`TFLBertQuestionAnswerer`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h)
+for [MobileBert](https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1)
+in Swift.
+
+```swift
+  static let mobileBertModelPath = "path/to/model.tflite";
+  // Create the API from a model file and vocabulary file
+  let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
+      modelPath: mobileBertModelPath)
+
+  static let context = ...; // context of a question to be answered
+  static let question = ...; // question to be answered
+  // ask a question
+  let answers = mobileBertAnswerer.answer(
+      context: context, question: question)
+  // answers[0].text is the best answer
+```
+
+#### Building the API
+
+<div align="center">![ios_task_api](images/ios_task_api.svg)
+<div align="center">Figure 4. iOS Task API
+<div align="left">
+
+iOS API is a simple ObjC wrapper on top of native API. Build the API by
+following the steps below:
+
+*   __Define the ObjC wrapper__ - Define an ObjC class and delegate the
+    implementations to the corresponding native API object. Note the native
+    dependencies can only appear in a .mm file due to Swift's inability to
+    interop with C++.
+
+    *   .h file
+
+    ```objc
+      @interface TFLBertQuestionAnswerer : NSObject
+
+      // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer
+      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
+                                                    vocabPath:(NSString*)vocabPath
+          NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:));
+
+      // Delegate calls to the native BertQuestionAnswerer::Answer
+      - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
+                                         question:(NSString*)question
+          NS_SWIFT_NAME(answer(context:question:));
+      @end
+    ```
+
+    *   .mm file
+
+    ```objc
+      using BertQuestionAnswererCPP = ::tflite::task::text::qa::BertQuestionAnswerer;
+
+      @implementation TFLBertQuestionAnswerer {
+        // define an iVar for the native API object
+        std::unique_ptr<BertQuestionAnswererCPP> _bertQuestionAnswerer;
+      }
+
+      // Initialize the native API object
+      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath
+                                              vocabPath:(NSString *)vocabPath {
+        absl::StatusOr<std::unique_ptr<BertQuestionAnswererCPP>> cQuestionAnswerer =
+            BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
+                                                                MakeString(vocabPath));
+        _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
+        return [[TFLBertQuestionAnswerer alloc]
+            initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
+      }
+
+      // Calls the native API and converts C++ results into ObjC results
+      - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
+        std::vector<QaAnswerCPP> results =
+          _bertQuestionAnswerer->Answer(MakeString(context), MakeString(question));
+        return [self arrayFromVector:results];
+      }
+      @end
+    ```
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md
new file mode 100644
index 0000000..94d6b51
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md
@@ -0,0 +1,170 @@
+# Integrate image classifiers
+
+Image classification is a common use of machine learning to identify what an
+image represents. For example, we might want to know what type of animal appears
+in a given picture. The task of predicting what an image represents is called
+_image classification_. An image classifier is trained to recognize various
+classes of images. For example, a model might be trained to recognize photos
+representing three different types of animals: rabbits, hamsters, and dogs. See
+the
+[introduction of image classification](../../models/image_classification/overview.md)
+for more information about image classifiers.
+
+Use the Task Library `ImageClassifier` API to deploy your custom image
+classifiers or pretrained ones into your mobile apps.
+
+## Key features of the ImageClassifier API
+
+*   Input image processing, including rotation, resizing, and color space
+    conversion.
+
+*   Region of interest of the input image.
+
+*   Label map locale.
+
+*   Score threshold to filter results.
+
+*   Top-k classification results.
+
+*   Label allowlist and denylist.
+
+## Supported image classifier models
+
+The following models are guaranteed to be compatible with the `ImageClassifier`
+API.
+
+*   Models created by
+    [TensorFlow Lite Model Maker for Image Classification](https://www.tensorflow.org/lite/tutorials/model_maker_image_classification).
+
+*   The
+    [pretrained image classification models from TensorFlow Lite Hosted Models](https://www.tensorflow.org/lite/guide/hosted_models#image_classification).
+
+*   The
+    [pretrained image classification models on TensorFlow Hub](https://tfhub.dev/tensorflow/collections/lite/task-library/image-classifier/1).
+
+*   Models created by
+    [AutoML Vision Edge Image Classification](https://cloud.google.com/vision/automl/docs/edge-quickstart).
+
+*   Custom models that meet the
+    [model compatibility requirements](#model-compatibility-requirements).
+
+## Run inference in Java
+
+### Step 1: Import Gradle dependency and other settings
+
+Copy the `.tflite` model file to the assets directory of the Android module
+where the model will be run. Specify that the file should not be compressed, and
+add the TensorFlow Lite library to the module’s `build.gradle` file:
+
+```java
+android {
+    // Other settings
+
+    // Specify tflite file should not be compressed for the app apk
+    aaptOptions {
+        noCompress "tflite"
+    }
+
+}
+
+dependencies {
+    // Other dependencies
+
+    // Import the Task Vision Library dependency
+    implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
+
+}
+```
+
+### Step 2: Using the model
+
+```java
+// Initialization
+ImageClassifierOptions options = ImageClassifierOptions.builder().setMaxResults(1).build();
+ImageClassifier imageClassifier = ImageClassifier.createFromFileAndOptions(context, modelFile, options);
+
+// Run inference
+List<Classifications> results = imageClassifier.classify(image);
+```
+
+See the
+[source code and javadoc](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java)
+for more options to configure `ImageClassifier`.
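+
+For instance, here is a hedged sketch that filters results by score and prints
+the remaining categories (assuming a `setScoreThreshold` builder setter per the
+javadoc above, and the `Category` getters from the support library):
+
+```java
+// Keep at most 3 results scoring above 0.5
+ImageClassifierOptions options =
+    ImageClassifierOptions.builder()
+        .setScoreThreshold(0.5f)
+        .setMaxResults(3)
+        .build();
+ImageClassifier imageClassifier =
+    ImageClassifier.createFromFileAndOptions(context, modelFile, options);
+
+List<Classifications> results = imageClassifier.classify(image);
+for (Classifications classifications : results) {
+    for (Category category : classifications.getCategories()) {
+        System.out.println(category.getLabel() + ": " + category.getScore());
+    }
+}
+```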
+
+## Run inference in C++
+
+Note: We are working on improving the usability of the C++ Task Library, such as
+providing prebuilt binaries and creating user-friendly workflows to build from
+source code. The C++ API may be subject to change.
+
+```c++
+// Initialization
+ImageClassifierOptions options;
+options.mutable_model_file_with_metadata()->set_file_name(model_file);
+std::unique_ptr<ImageClassifier> image_classifier = ImageClassifier::CreateFromOptions(options).value();
+
+// Run inference
+const ClassificationResult result = image_classifier->Classify(*frame_buffer).value();
+```
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/vision/image_classifier.h)
+for more options to configure `ImageClassifier`.
+
+## Example results
+
+Here is an example of the classification results of a
+[bird classifier](https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3).
+
+<img src="images/sparrow.jpg" alt="sparrow" width="50%">
+
+```
+Results:
+  Rank #0:
+   index       : 671
+   score       : 0.91406
+   class name  : /m/01bwb9
+   display name: Passer domesticus
+  Rank #1:
+   index       : 670
+   score       : 0.00391
+   class name  : /m/01bwbt
+   display name: Passer montanus
+  Rank #2:
+   index       : 495
+   score       : 0.00391
+   class name  : /m/0bwm6m
+   display name: Passer italiae
+```
+
+Try out the simple
+[CLI demo tool for ImageClassifier](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/examples/task/vision/desktop#image-classifier)
+with your own model and test data.
+
+## Model compatibility requirements
+
+The `ImageClassifier` API expects a TFLite model with mandatory
+[TFLite Model Metadata](../../convert/metadata.md).
+
+The compatible image classifier models should meet the following requirements:
+
+*   Input image tensor (kTfLiteUInt8/kTfLiteFloat32)
+
+    -   image input of size `[batch x height x width x channels]`.
+    -   batch inference is not supported (`batch` is required to be 1).
+    -   only RGB inputs are supported (`channels` is required to be 3).
+    -   if type is kTfLiteFloat32, NormalizationOptions are required to be
+        attached to the metadata for input normalization.
+
+*   Output score tensor (kTfLiteUInt8/kTfLiteFloat32)
+
+    -   with `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1
+        x 1 x N]`
+    -   optional (but recommended) label map(s) as AssociatedFile-s with type
+        TENSOR_AXIS_LABELS, containing one label per line. The first such
+        AssociatedFile (if any) is used to fill the `label` field (named as
+        `class_name` in C++) of the results. The `display_name` field is filled
+        from the AssociatedFile (if any) whose locale matches the
+        `display_names_locale` field of the `ImageClassifierOptions` used at
+        creation time ("en" by default, i.e. English). If none of these are
+        available, only the `index` field of the results will be filled.
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_segmenter.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_segmenter.md
new file mode 100644
index 0000000..1239ce1
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_segmenter.md
@@ -0,0 +1,162 @@
+# Integrate image segmenters
+
+Image segmenters predict whether each pixel of an image is associated with a
+certain class. This is in contrast to
+<a href="../../models/object_detection/overview.md">object detection</a>, which
+detects objects in rectangular regions, and
+<a href="../../models/image_classification/overview.md">image
+classification</a>, which classifies the overall image. See the
+[introduction of image segmentation](../../models/segmentation/overview.md) for
+more information about image segmenters.
+
+Use the Task Library `ImageSegmenter` API to deploy your custom image segmenters
+or pretrained ones into your mobile apps.
+
+## Key features of the ImageSegmenter API
+
+*   Input image processing, including rotation, resizing, and color space
+    conversion.
+
+*   Label map locale.
+
+*   Two output types, category mask and confidence masks.
+
+*   Colored labels for display purposes.
+
+## Supported image segmenter models
+
+The following models are guaranteed to be compatible with the `ImageSegmenter`
+API.
+
+*   The
+    [pretrained image segmentation models on TensorFlow Hub](https://tfhub.dev/tensorflow/collections/lite/task-library/image-segmenter/1).
+
+*   Custom models that meet the
+    [model compatibility requirements](#model-compatibility-requirements).
+
+## Run inference in Java
+
+### Step 1: Import Gradle dependency and other settings
+
+Copy the `.tflite` model file to the assets directory of the Android module
+where the model will be run. Specify that the file should not be compressed, and
+add the TensorFlow Lite library to the module’s `build.gradle` file:
+
+```java
+android {
+    // Other settings
+
+    // Specify tflite file should not be compressed for the app apk
+    aaptOptions {
+        noCompress "tflite"
+    }
+
+}
+
+dependencies {
+    // Other dependencies
+
+    // Import the Task Vision Library dependency
+    implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
+}
+```
+
+### Step 2: Using the model
+
+```java
+// Initialization
+ImageSegmenterOptions options = ImageSegmenterOptions.builder().setOutputType(OutputType.CONFIDENCE_MASK).build();
+ImageSegmenter imageSegmenter = ImageSegmenter.createFromFileAndOptions(context, modelFile, options);
+
+// Run inference
+List<Segmentation> results = imageSegmenter.segment(image);
+```
+
+See the
+[source code and javadoc](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java)
+for more options to configure `ImageSegmenter`.
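+
+A minimal sketch of reading the output (assuming, per the linked source, that
+each `Segmentation` exposes its masks and colored labels through getters; names
+may differ across versions):
+
+```java
+// With OutputType.CONFIDENCE_MASK there is one mask per class; with
+// OutputType.CATEGORY_MASK there is a single mask whose pixel values
+// are class indices.
+Segmentation segmentation = results.get(0);
+List<TensorImage> masks = segmentation.getMasks();
+List<ColoredLabel> coloredLabels = segmentation.getColoredLabels();
+System.out.println("masks: " + masks.size() + ", labels: " + coloredLabels.size());
+```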
+
+## Run inference in C++
+
+Note: We are working on improving the usability of the C++ Task Library, such as
+providing prebuilt binaries and creating user-friendly workflows to build from
+source code. The C++ API may be subject to change.
+
+```c++
+// Initialization
+ImageSegmenterOptions options;
+options.mutable_model_file_with_metadata()->set_file_name(model_file);
+std::unique_ptr<ImageSegmenter> image_segmenter = ImageSegmenter::CreateFromOptions(options).value();
+
+// Run inference
+const SegmentationResult result = image_segmenter->Segment(*frame_buffer).value();
+```
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/vision/image_segmenter.h)
+for more options to configure `ImageSegmenter`.
+
+## Example results
+
+Here is an example of the segmentation results of
+[deeplab_v3](https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1), a
+generic segmentation model available on TensorFlow Hub.
+
+<img src="images/plane.jpg" alt="plane" width="50%">
+
+```
+Color Legend:
+ (r: 000, g: 000, b: 000):
+  index       : 0
+  class name  : background
+ (r: 128, g: 000, b: 000):
+  index       : 1
+  class name  : aeroplane
+
+# (omitting multiple lines for conciseness) ...
+
+ (r: 128, g: 192, b: 000):
+  index       : 19
+  class name  : train
+ (r: 000, g: 064, b: 128):
+  index       : 20
+  class name  : tv
+Tip: use a color picker on the output PNG file to inspect the output mask with
+this legend.
+```
+
+The segmentation category mask should look like:
+
+<img src="images/segmentation-output.png" alt="segmentation-output" width="30%">
+
+Try out the simple
+[CLI demo tool for ImageSegmenter](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/examples/task/vision/desktop#image-segmenter)
+with your own model and test data.
+
+## Model compatibility requirements
+
+The `ImageSegmenter` API expects a TFLite model with mandatory
+[TFLite Model Metadata](../../convert/metadata.md).
+
+*   Input image tensor (kTfLiteUInt8/kTfLiteFloat32)
+
+    -   image input of size `[batch x height x width x channels]`.
+    -   batch inference is not supported (`batch` is required to be 1).
+    -   only RGB inputs are supported (`channels` is required to be 3).
+    -   if type is kTfLiteFloat32, NormalizationOptions are required to be
+        attached to the metadata for input normalization.
+
+*   Output masks tensor: (kTfLiteUInt8/kTfLiteFloat32)
+
+    -   tensor of size `[batch x mask_height x mask_width x num_classes]`, where
+        `batch` is required to be 1, `mask_width` and `mask_height` are the
+        dimensions of the segmentation masks produced by the model, and
+        `num_classes` is the number of classes supported by the model.
+    -   optional (but recommended) label map(s) can be attached as
+        AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per
+        line. The first such AssociatedFile (if any) is used to fill the `label`
+        field (named as `class_name` in C++) of the results. The `display_name`
+        field is filled from the AssociatedFile (if any) whose locale matches
+        the `display_names_locale` field of the `ImageSegmenterOptions` used at
+        creation time ("en" by default, i.e. English). If none of these are
+        available, only the `index` field of the results will be filled.
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/android_task_api.svg b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/android_task_api.svg
new file mode 100644
index 0000000..c9554b4
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/android_task_api.svg
@@ -0,0 +1 @@
+<?xml version="1.0" encoding="utf-8" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 20010904//EN" "http://www.w3.org/TR/2001/REC-SVG-20010904/DTD/svg10.dtd"><svg xmlns="http://www.w3.org/2000/svg" width="491" height="594" xmlns:xlink="http://www.w3.org/1999/xlink"><desc style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Created with Raphaël 2.2.0</desc><defs style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><marker id="raphael-marker-endblock55-objjks7y" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#00acc1" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-obj68k5o" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#00acc1" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-objev9wf" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#d4e157" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-objvt0oa" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#d4e157" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker></defs><rect x="10" y="10" width="163.625" height="28" rx="0" ry="0" fill="#0000ff" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="15" y="15" width="153.625" height="18" rx="0" ry="0" fill="#0000ff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="91.8125" y="24" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Android Task API</tspan></text><rect x="10" y="48" width="154.4375" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="20" y="67.609375" width="134.4375" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="87.21875" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Android client</tspan></text><rect x="10" y="517.3125" width="154.4375" height="57.21875" 
rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="20" y="536.921875" width="134.4375" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="87.21875" y="545.921875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Android client</tspan></text><path fill="none" stroke="#000000" d="M87.21875,105.21875L87.21875,517.3125" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="184.4375" y="48" width="106.4375" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="194.4375" y="58" width="86.4375" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="237.65625" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Java </tspan><tspan dy="19.2" x="237.65625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> Interface</tspan></text><rect x="184.4375" y="517.3125" width="106.4375" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="194.4375" y="527.3125" width="86.4375" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="237.65625" y="545.921875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Java </tspan><tspan dy="19.2" x="237.65625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> Interface</tspan></text><path fill="none" stroke="#000000" d="M237.65625,105.21875L237.65625,517.3125" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="310.875" y="48" width="87.21875" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="320.875" y="58" width="67.21875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="354.484375" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="354.484375" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API</tspan></text><rect x="310.875" y="517.3125" width="87.21875" height="57.21875" rx="0" 
ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="320.875" y="527.3125" width="67.21875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="354.484375" y="545.921875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="354.484375" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API</tspan></text><path fill="none" stroke="#000000" d="M354.484375,105.21875L354.484375,517.3125" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="104.828125" y="111.609375" width="115.21875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="162.4375" y="130.21875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#00acc1" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Java/Kotlin </tspan><tspan dy="19.2" x="162.4375" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">API input</tspan></text><path fill="none" stroke="#00acc1" d="M87.21875,162.4375C87.21875,162.4375,206.54771423339844,162.4375,232.65092515945435,162.4375" stroke-width="2" marker-end="url(#raphael-marker-endblock55-objjks7y)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="218.25" y="182.4375" width="38.8125" height="28" rx="0" ry="0" fill="#00acc1" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="223.25" y="187.4375" width="28.8125" height="18" rx="0" ry="0" fill="#00acc1" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="237.65625" y="196.4375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">JNI</tspan></text><rect x="252.859375" y="216.828125" width="86.421875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="296.0703125" y="235.4375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#00acc1" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="296.0703125" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API input</tspan></text><path fill="none" stroke="#00acc1" d="M237.65625,267.65625C237.65625,267.65625,327.1527045266703,267.65625,349.483118717651,267.65625" stroke-width="2" marker-end="url(#raphael-marker-endblock55-obj68k5o)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect 
x="301.46875" y="287.65625" width="106.03125" height="47.21875" rx="0" ry="0" fill="#ff6f00" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="306.46875" y="292.65625" width="96.03125" height="37.21875" rx="0" ry="0" fill="#ff6f00" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="354.484375" y="311.265625" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">model </tspan><tspan dy="19.2" x="354.484375" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> invocation</tspan></text><rect x="248.0625" y="341.265625" width="96.015625" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="296.0703125" y="359.875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#d4e157" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="296.0703125" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API output</tspan></text><path fill="none" stroke="#d4e157" d="M354.484375,392.09375C354.484375,392.09375,264.9879204733297,392.09375,242.657506282349,392.09375" stroke-width="2" marker-end="url(#raphael-marker-endblock55-objev9wf)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="218.25" y="412.09375" width="38.8125" height="28" rx="0" ry="0" fill="#d4e157" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="223.25" y="417.09375" width="28.8125" height="18" rx="0" ry="0" fill="#d4e157" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="237.65625" y="426.09375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">JNI</tspan></text><rect x="104.828125" y="446.484375" width="115.21875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="162.4375" y="465.09375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#d4e157" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Java/Kotlin </tspan><tspan dy="19.2" x="162.4375" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API output</tspan></text><path fill="none" stroke="#d4e157" d="M237.65625,497.3125C237.65625,497.3125,118.32728576660156,497.3125,92.22407484054565,497.3125" stroke-width="2" marker-end="url(#raphael-marker-endblock55-objvt0oa)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path></svg>
\ No newline at end of file
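(The Android diagram above reduces to: the Java/Kotlin Task API marshals inputs over JNI, the native API runs the model, and results are marshalled back. As a minimal sketch of that client-side flow, not part of this diff, assuming the Task Library's Java ObjectDetector API and an illustrative model asset name:

```java
import android.content.Context;
import android.graphics.Bitmap;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.task.vision.detector.Detection;
import org.tensorflow.lite.task.vision.detector.ObjectDetector;
import org.tensorflow.lite.task.vision.detector.ObjectDetector.ObjectDetectorOptions;

public final class DetectorSketch {
  /** "model.tflite" is an illustrative asset name, not from this diff. */
  public static List<Detection> detect(Context context, Bitmap bitmap) throws IOException {
    ObjectDetectorOptions options =
        ObjectDetectorOptions.builder().setMaxResults(5).build();
    // Loading the model crosses the JNI boundary into the native Task API
    // (the first left-to-right arrow pair in the diagram).
    ObjectDetector detector =
        ObjectDetector.createFromFileAndOptions(context, "model.tflite", options);
    // detect() converts the bitmap to a native input, invokes the model,
    // and converts the native output back into Java Detection objects.
    return detector.detect(TensorImage.fromBitmap(bitmap));
  }
}
```

Both JNI crossings in the diagram stay hidden behind createFromFileAndOptions() and detect().)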
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/detection-output.png b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/detection-output.png
new file mode 100644
index 0000000..c8d56f4
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/detection-output.png
Binary files differ
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/dogs.jpg b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/dogs.jpg
new file mode 100644
index 0000000..9db4bee
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/dogs.jpg
Binary files differ
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/ios_task_api.svg b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/ios_task_api.svg
new file mode 100644
index 0000000..615b123
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/ios_task_api.svg
@@ -0,0 +1 @@
+<?xml version="1.0" encoding="utf-8" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 20010904//EN" "http://www.w3.org/TR/2001/REC-SVG-20010904/DTD/svg10.dtd"><svg xmlns="http://www.w3.org/2000/svg" width="452" height="632" xmlns:xlink="http://www.w3.org/1999/xlink"><desc style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Created with Raphaël 2.2.0</desc><defs style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><marker id="raphael-marker-endblock55-objk5zfc" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#00acc1" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-obj0wlxn" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#00acc1" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-objqcufj" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#d4e157" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-objy4vsc" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#d4e157" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker></defs><rect x="10" y="10" width="125.21875" height="28" rx="0" ry="0" fill="#0000ff" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="15" y="15" width="115.21875" height="18" rx="0" ry="0" fill="#0000ff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="72.609375" y="24" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">iOS Task API</tspan></text><rect x="10" y="48" width="116.03125" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="20" y="67.609375" width="96.03125" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="68.015625" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">iOS client</tspan></text><rect x="10" y="555.75" width="116.03125" height="57.21875" 
rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="20" y="575.359375" width="96.03125" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="68.015625" y="584.359375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">iOS client</tspan></text><path fill="none" stroke="#000000" d="M68.015625,105.21875L68.015625,555.75" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="146.03125" y="48" width="106.4375" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="156.03125" y="58" width="86.4375" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="199.25" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">ObjC </tspan><tspan dy="19.2" x="199.25" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> Interface</tspan></text><rect x="146.03125" y="555.75" width="106.4375" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="156.03125" y="565.75" width="86.4375" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="199.25" y="584.359375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">ObjC </tspan><tspan dy="19.2" x="199.25" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> Interface</tspan></text><path fill="none" stroke="#000000" d="M199.25,105.21875L199.25,555.75" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="272.46875" y="48" width="87.21875" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="282.46875" y="58" width="67.21875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="316.078125" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="316.078125" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API</tspan></text><rect x="272.46875" y="555.75" width="87.21875" height="57.21875" rx="0" ry="0" fill="none" 
stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="282.46875" y="565.75" width="67.21875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="316.078125" y="584.359375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="316.078125" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API</tspan></text><path fill="none" stroke="#000000" d="M316.078125,105.21875L316.078125,555.75" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="80.8125" y="111.609375" width="105.640625" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="133.6328125" y="130.21875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#00acc1" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Swift/ObjC </tspan><tspan dy="19.2" x="133.6328125" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">API input</tspan></text><path fill="none" stroke="#00acc1" d="M68.015625,162.4375C68.015625,162.4375,170.23761106934398,162.4375,194.2458021334819,162.4375" stroke-width="2" marker-end="url(#raphael-marker-endblock55-objk5zfc)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="160.640625" y="182.4375" width="77.21875" height="47.21875" rx="0" ry="0" fill="#00acc1" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="165.640625" y="187.4375" width="67.21875" height="37.21875" rx="0" ry="0" fill="#00acc1" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="199.25" y="206.046875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="199.25" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> wrapper</tspan></text><rect x="214.453125" y="236.046875" width="86.421875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="257.6640625" y="254.65625" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#00acc1" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="257.6640625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API input</tspan></text><path fill="none" stroke="#00acc1" d="M199.25,286.875C199.25,286.875,288.7464545266703,286.875,311.076868717651,286.875" stroke-width="2" 
marker-end="url(#raphael-marker-endblock55-obj0wlxn)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="263.0625" y="306.875" width="106.03125" height="47.21875" rx="0" ry="0" fill="#ff6f00" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="268.0625" y="311.875" width="96.03125" height="37.21875" rx="0" ry="0" fill="#ff6f00" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="316.078125" y="330.484375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">model </tspan><tspan dy="19.2" x="316.078125" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> invocation</tspan></text><rect x="209.65625" y="360.484375" width="96.015625" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="257.6640625" y="379.09375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#d4e157" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="257.6640625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API output</tspan></text><path fill="none" stroke="#d4e157" d="M316.078125,411.3125C316.078125,411.3125,226.58167047332972,411.3125,204.251256282349,411.3125" stroke-width="2" marker-end="url(#raphael-marker-endblock55-objqcufj)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="160.640625" y="431.3125" width="77.21875" height="47.21875" rx="0" ry="0" fill="#d4e157" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="165.640625" y="436.3125" width="67.21875" height="37.21875" rx="0" ry="0" fill="#d4e157" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="199.25" y="454.921875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="199.25" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> wrapper</tspan></text><rect x="80.8125" y="484.921875" width="105.640625" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="133.6328125" y="503.53125" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#d4e157" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Swift/ObjC </tspan><tspan dy="19.2" x="133.6328125" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> API output</tspan></text><path fill="none" stroke="#d4e157" 
d="M199.25,535.75C199.25,535.75,97.02801393065602,535.75,73.0198228665181,535.75" stroke-width="2" marker-end="url(#raphael-marker-endblock55-objy4vsc)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path></svg>
\ No newline at end of file
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/native_task_api.svg b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/native_task_api.svg
new file mode 100644
index 0000000..e87c95a
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/native_task_api.svg
@@ -0,0 +1 @@
+<?xml version="1.0" encoding="utf-8" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 20010904//EN" "http://www.w3.org/TR/2001/REC-SVG-20010904/DTD/svg10.dtd"><svg xmlns="http://www.w3.org/2000/svg" width="659" height="556" xmlns:xlink="http://www.w3.org/1999/xlink"><desc style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">Created with Raphaël 2.2.0</desc><defs style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><marker id="raphael-marker-endblock55-obji09o4" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#00acc1" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-obj7hc3v" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#ff6f00" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-objuf9tb" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#ff6f00" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker><marker id="raphael-marker-endblock55-objokijf" markerHeight="5" markerWidth="5" orient="auto" refX="2.5" refY="2.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"><use xlink:href="#raphael-marker-block" transform="rotate(180 2.5 2.5) scale(1,1)" stroke-width="1.0000" fill="#d4e157" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></use></marker></defs><rect x="10" y="10" width="154.03125" height="28" rx="0" ry="0" fill="#0000ff" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="15" y="15" width="144.03125" height="18" rx="0" ry="0" fill="#0000ff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="87.015625" y="24" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native Task API</tspan></text><rect x="10" y="48" width="116.03125" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="20" y="67.609375" width="96.03125" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="68.015625" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">C++ client</tspan></text><rect x="10" y="478.875" width="116.03125" 
height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="20" y="498.484375" width="96.03125" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="68.015625" y="507.484375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">C++ client</tspan></text><path fill="none" stroke="#000000" d="M68.015625,105.21875L68.015625,478.875" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="146.03125" y="48" width="116.03125" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="156.03125" y="67.609375" width="96.03125" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="204.046875" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">preprocess</tspan></text><rect x="146.03125" y="478.875" width="116.03125" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="156.03125" y="498.484375" width="96.03125" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="204.046875" y="507.484375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">preprocess</tspan></text><path fill="none" stroke="#000000" d="M204.046875,105.21875L204.046875,478.875" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="295.65625" y="48" width="87.21875" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="305.65625" y="58" width="67.21875" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="339.265625" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">TFLite </tspan><tspan dy="19.2" x="339.265625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> runtime</tspan></text><rect x="295.65625" y="478.875" width="87.21875" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="305.65625" y="488.875" width="67.21875" height="37.21875" 
rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="339.265625" y="507.484375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">TFLite </tspan><tspan dy="19.2" x="339.265625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> runtime</tspan></text><path fill="none" stroke="#000000" d="M339.265625,105.21875L339.265625,478.875" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="421.296875" y="48" width="125.625" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="431.296875" y="67.609375" width="105.625" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="484.109375" y="76.609375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">postprocess</tspan></text><rect x="421.296875" y="478.875" width="125.625" height="57.21875" rx="0" ry="0" fill="none" stroke="#000000" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="431.296875" y="498.484375" width="105.625" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="484.109375" y="507.484375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#000000" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">postprocess</tspan></text><path fill="none" stroke="#000000" d="M484.109375,105.21875L484.109375,478.875" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="92.8125" y="111.609375" width="86.4375" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="136.03125" y="130.21875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#00acc1" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="136.03125" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">API input</tspan></text><path fill="none" stroke="#00acc1" d="M68.015625,162.4375C68.015625,162.4375,174.50227653980255,162.4375,199.04412201186642,162.4375" stroke-width="2" marker-end="url(#raphael-marker-endblock55-obji09o4)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="131.828125" y="182.4375" width="144.4375" height="28" rx="0" ry="0" fill="#00acc1" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="136.828125" y="187.4375" 
width="134.4375" height="18" rx="0" ry="0" fill="#00acc1" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="204.046875" y="196.4375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">data to tensor</tspan></text><rect x="214.046875" y="226.4375" width="115.21875" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="271.65625" y="235.4375" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ff6f00" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">input tensor</tspan></text><path fill="none" stroke="#ff6f00" d="M204.046875,248.4375C204.046875,248.4375,309.8390848468989,248.4375,334.2725395625157,248.4375" stroke-width="2" marker-end="url(#raphael-marker-endblock55-obj7hc3v)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="286.25" y="268.4375" width="106.03125" height="47.21875" rx="0" ry="0" fill="#ff6f00" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="291.25" y="273.4375" width="96.03125" height="37.21875" rx="0" ry="0" fill="#ff6f00" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="339.265625" y="292.046875" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">model </tspan><tspan dy="19.2" x="339.265625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"> invocation</tspan></text><rect x="349.265625" y="331.65625" width="124.84375" height="18" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="411.6875" y="340.65625" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ff6f00" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">output tensor</tspan></text><path fill="none" stroke="#ff6f00" d="M339.265625,353.65625C339.265625,353.65625,453.5913529600948,353.65625,479.1025139355652,353.65625" stroke-width="2" marker-end="url(#raphael-marker-endblock55-objuf9tb)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path><rect x="411.890625" y="373.65625" width="144.4375" height="28" rx="0" ry="0" fill="#d4e157" stroke="#ffffff" stroke-width="2" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><rect x="416.890625" y="378.65625" width="134.4375" height="18" rx="0" ry="0" fill="#d4e157" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="484.109375" y="387.65625" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#ffffff" 
style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="5.5" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">tensor to data</tspan></text><rect x="228.046875" y="408.046875" width="96.03125" height="37.21875" rx="0" ry="0" fill="#ffffff" stroke="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></rect><text x="276.0625" y="426.65625" text-anchor="middle" font-family="Andale Mono, monospace" font-size="16px" stroke="none" fill="#d4e157" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0); text-anchor: middle; font-family: &quot;Andale Mono&quot;, monospace; font-size: 16px;"><tspan dy="-4.1015625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">native </tspan><tspan dy="19.2" x="276.0625" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);">API output</tspan></text><path fill="none" stroke="#d4e157" d="M484.109375,458.875C484.109375,458.875,120.11422207392752,458.875,73.01523988378631,458.875" stroke-width="2" marker-end="url(#raphael-marker-endblock55-objokijf)" stroke-dasharray="none" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path></svg>
\ No newline at end of file
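(The native diagram above shows the three stages every Task API wraps: preprocess ("data to tensor"), model invocation by the TFLite runtime, and postprocess ("tensor to data"). For contrast, a minimal sketch of doing those stages by hand with the raw Interpreter API, not part of this diff; the float32 layout and output shape are illustrative assumptions:

```java
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.tensorflow.lite.Interpreter;

public final class ManualFlowSketch {
  /** pixels/numClasses are illustrative; a real model defines its own shapes. */
  public static float[][] run(File modelFile, float[] pixels, int numClasses) {
    Interpreter interpreter = new Interpreter(modelFile);
    try {
      // Preprocess / "data to tensor": pack the input in the layout the
      // model expects (float32 assumed here).
      ByteBuffer input =
          ByteBuffer.allocateDirect(pixels.length * 4).order(ByteOrder.nativeOrder());
      for (float p : pixels) {
        input.putFloat(p);
      }
      input.rewind();
      // "Model invocation": the TFLite runtime executes the graph.
      float[][] output = new float[1][numClasses];
      interpreter.run(input, output);
      // Postprocess / "tensor to data": the raw scores would now be mapped
      // back to labels, boxes, or other client-facing types.
      return output;
    } finally {
      interpreter.close();
    }
  }
}
```

The Task APIs in the diagrams exist precisely so that clients never write this buffer-packing and score-decoding code themselves.)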
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/plane.jpg b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/plane.jpg
new file mode 100644
index 0000000..0edefa4
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/plane.jpg
Binary files differ
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/prebuilt_task_apis.svg b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/prebuilt_task_apis.svg
new file mode 100644
index 0000000..c9aced3
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/prebuilt_task_apis.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 560.3700787401575 156.24409448818898" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l560.37006 0l0 156.2441l-560.37006 0l0 -156.2441z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l560.37006 0l0 156.2441l-560.37006 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m143.04504 24.922682l8.283463 0l0 3.8740158l-8.283463 0z" fill-rule="evenodd"/><path fill="#ff9900" d="m147.16116 8.367561l0 0c0 -2.968771 2.406662 -5.3754354 5.3754425 -5.3754354l78.225494 0c1.4256592 0 2.792923 0.566339 3.8010101 1.5744286c1.0080872 1.00809 1.5744324 2.3753529 1.5744324 3.8010068l0 21.501097c0 2.968771 -2.406662 5.375437 -5.3754425 5.375437l-78.225494 0c-2.9687805 0 -5.3754425 -2.4066658 -5.3754425 -5.375437z" fill-rule="evenodd"/><path fill="#000000" d="m161.50356 15.764359l3.0 0q2.296875 0 2.296875 1.781251q0 0.984375 -0.984375 1.5q0.640625 0.15625 0.984375 0.59375q0.359375 0.421875 0.359375 1.09375q0 1.03125 -0.703125 1.578125q-0.6875 0.546875 -2.015625 0.546875l-3.0 0l0 -0.859375l0.359375 -0.03125q0.296875 -0.03125 0.296875 -0.296875l0 -4.96875l-0.59375 -0.03125l0 -0.90625095zm1.90625 3.859376l0 2.234375l1.0 0q1.375 0 1.375 -1.125q0 -0.546875 -0.34375 -0.828125q-0.34375 -0.28125 -0.984375 -0.28125l-1.046875 0zm0 -2.875l0 1.9375l0.78125 0q0.65625 0 0.96875 -0.265625q0.328125 -0.28125 0.328125 -0.765625q0 -0.5 -0.328125 -0.703125q-0.3125 -0.203125 -0.890625 -0.203125l-0.859375 0zm9.227631 0.96875l0 3.953125q0 0.171875 0.046875 0.234375q0.0625 0.0625 0.21875 0.078125l0.34375 0.015625l0 0.859375l-1.703125 0l0 -0.625l-0.03125 0q-0.53125 0.765625 -1.453125 0.765625q-1.09375 0 -1.625 -0.6875q-0.515625 -0.703125 -0.515625 -1.90625q0 -1.453125 0.703125 -2.25q0.703125 -0.8125 2.109375 -0.8125q0.90625 0 1.90625 0.375zm-1.234375 3.453125l0 -2.765625q-0.296875 -0.140625 -0.828125 -0.140625q-0.71875 0 -1.046875 0.578125q-0.3125 0.578125 -0.3125 1.515625q0 1.734375 1.109375 1.734375q0.46875 0 0.765625 -0.28125q0.3125 -0.28125 0.3125 -0.640625zm5.276535 -2.859375q-0.25 -0.09375 -0.625 -0.09375q-0.359375 0 -0.578125 0.171875q-0.21875 0.15625 -0.21875 0.390625q0 0.234375 0.078125 0.375q0.09375 0.140625 0.265625 0.234375q0.265625 0.140625 0.625 0.234375q0.375 0.09375 0.5625 0.15625q0.1875 0.0625 0.453125 0.203125q0.265625 0.140625 0.40625 0.296875q0.375 0.390625 0.375 1.015625q0 0.796875 -0.578125 1.25q-0.578125 0.453125 -1.484375 0.453125q-1.296875 0 -1.953125 -0.328125l0 -1.484375l0.953125 -0.078125l0 0.515625q0 0.46875 0.890625 0.46875q0.90625 0 0.90625 -0.65625q0 -0.234375 -0.15625 -0.375q-0.15625 -0.15625 -0.3125 -0.203125q-0.140625 -0.0625 -0.359375 -0.109375q-0.203125 -0.046875 -0.40625 -0.09375q-0.1875 -0.0625 -0.421875 -0.15625q-0.21875 -0.09375 -0.5 -0.265625q-0.546875 -0.34375 -0.546875 -1.171875q0 -0.828125 0.578125 -1.265625q0.59375 -0.453125 1.484375 -0.453125q0.890625 0 1.765625 0.421875l0 1.28125l-0.953125 0.078125l0 -0.453125q0 -0.265625 -0.25 -0.359375zm4.69664 -0.96875q0.796875 0 1.265625 0.390625q0.46875 0.390625 0.46875 1.09375q0 0.46875 -0.203125 0.828125q-0.203125 0.34375 -0.5 0.546875q-0.296875 0.203125 -0.734375 0.328125q-0.703125 0.21875 -1.609375 0.21875q0.046875 0.5625 0.359375 0.90625q0.3125 0.34375 0.96875 0.34375q0.671875 0 1.328125 -0.46875l0.40625 0.875q-0.203125 0.1875 -0.71875 0.390625q-0.5 0.203125 -1.15625 0.203125q-1.296875 0 -1.90625 
-0.71875q-0.609375 -0.71875 -0.609375 -1.96875q0 -1.265625 0.6875 -2.109375q0.703125 -0.859375 1.953125 -0.859375zm-0.484375 2.4375q0.390625 -0.078125 0.71875 -0.3125q0.328125 -0.25 0.328125 -0.578125q0 -0.640625 -0.640625 -0.640625q-0.59375 0 -0.921875 0.46875q-0.3125 0.46875 -0.34375 1.140625q0.46875 -0.015625 0.859375 -0.078125zm4.9389343 1.890625l0 -4.84375l-0.84375 0q-0.234375 0 -0.234375 0.328125l0 0.5l-1.09375 -0.078125l0 -1.812501l5.578125 0l0 1.812501l-1.09375 0.078125l0 -0.5q0 -0.171875 -0.046875 -0.25q-0.046875 -0.078125 -0.25 -0.078125l-0.734375 0l0 5.09375l0.84375 0.03125l0 0.90625l-3.0 0l0 -0.859375l0.5625 -0.03125q0.3125 -0.03125 0.3125 -0.296875zm8.70546 -3.953125l0 3.953125q0 0.171875 0.046875 0.234375q0.0625 0.0625 0.21875 0.078125l0.34375 0.015625l0 0.859375l-1.703125 0l0 -0.625l-0.03125 0q-0.53125 0.765625 -1.453125 0.765625q-1.09375 0 -1.625 -0.6875q-0.515625 -0.703125 -0.515625 -1.90625q0 -1.453125 0.703125 -2.25q0.703125 -0.8125 2.109375 -0.8125q0.90625 0 1.90625 0.375zm-1.234375 3.453125l0 -2.765625q-0.296875 -0.140625 -0.828125 -0.140625q-0.71875 0 -1.046875 0.578125q-0.3125 0.578125 -0.3125 1.515625q0 1.734375 1.109375 1.734375q0.46875 0 0.765625 -0.28125q0.3125 -0.28125 0.3125 -0.640625zm5.276535 -2.859375q-0.25 -0.09375 -0.625 -0.09375q-0.359375 0 -0.578125 0.171875q-0.21875 0.15625 -0.21875 0.390625q0 0.234375 0.078125 0.375q0.09375 0.140625 0.265625 0.234375q0.265625 0.140625 0.625 0.234375q0.375 0.09375 0.5625 0.15625q0.1875 0.0625 0.453125 0.203125q0.265625 0.140625 0.40625 0.296875q0.375 0.390625 0.375 1.015625q0 0.796875 -0.578125 1.25q-0.578125 0.453125 -1.484375 0.453125q-1.296875 0 -1.953125 -0.328125l0 -1.484375l0.953125 -0.078125l0 0.515625q0 0.46875 0.890625 0.46875q0.90625 0 0.90625 -0.65625q0 -0.234375 -0.15625 -0.375q-0.15625 -0.15625 -0.3125 -0.203125q-0.140625 -0.0625 -0.359375 -0.109375q-0.203125 -0.046875 -0.40625 -0.09375q-0.1875 -0.0625 -0.421875 -0.15625q-0.21875 -0.09375 -0.5 -0.265625q-0.546875 -0.34375 -0.546875 -1.171875q0 -0.828125 0.578125 -1.265625q0.59375 -0.453125 1.484375 -0.453125q0.890625 0 1.765625 0.421875l0 1.28125l-0.953125 0.078125l0 -0.453125q0 -0.265625 -0.25 -0.359375zm2.681015 3.359375l0 -5.203125q0 -0.15625 -0.0625 -0.21875q-0.0625 -0.078125 -0.203125 -0.078125l-0.375 -0.03125l0 -0.87500095l1.890625 0l0 4.609376q0.59375 -0.046875 1.03125 -0.359375q0.4375 -0.3125 0.59375 -0.8125q0.03125 -0.109375 0.03125 -0.171875q0 -0.171875 -0.203125 -0.171875l-0.40625 -0.03125l0 -0.828125l2.203125 0l0 0.859375l-0.5 0.03125q-0.140625 1.078125 -0.9375 1.703125l1.03125 1.875l0.65625 0.03125l0 0.859375l-1.5625 0l-1.203125 -2.21875q-0.359375 0.09375 -0.734375 0.15625l0 1.171875l0.015625 0l0.609375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.34375 -0.03125q0.15625 -0.015625 0.234375 -0.078125q0.078125 -0.078125 0.078125 -0.265625zm5.628067 0l1.71875 -4.96875l-0.671875 -0.0625l0 -0.87500095l2.4375 0l2.1875 6.156251l0 0l0.60939026 0.03125l0 0.90625l-2.7500153 0l0 -0.859375l0.40625 -0.03125q0.1875 -0.03125 0.25 -0.09375q0.0625 -0.0625 0 -0.234375l-0.21875 -0.65625l-2.5 0l-0.3125 0.9375l0.640625 0.03125l0 0.90625l-2.46875 0l0 -0.859375l0.359375 -0.03125q0.21875 -0.03125 0.3125 -0.296875zm2.71875 -4.703125l-0.984375 3.03125l1.984375 0l-0.96875 -3.03125l-0.03125 0zm4.3568115 6.90625l0 -5.15625q0 -0.1875 -0.0625 -0.25q-0.0625 -0.078125 -0.203125 -0.078125l-0.453125 -0.03125l0 -0.859375l1.796875 0l0 0.6875q0.171875 -0.3125 0.609375 -0.578125q0.453125 -0.265625 1.078125 -0.265625q1.921875 0 1.921875 2.703125q0 1.4375 -0.625 2.203125q-0.625 0.765625 
-1.703125 0.765625q-0.640625 0 -1.140625 -0.3125l0 1.484375l0.859375 0.03125l0 0.84375l-2.75 0l0 -0.8125l0.359375 -0.03125q0.15625 -0.015625 0.234375 -0.09375q0.078125 -0.0625 0.078125 -0.25zm3.375 -3.71875q0 -1.78125 -1.078125 -1.78125q-0.453125 0 -0.765625 0.28125q-0.3125 0.265625 -0.3125 0.640625l0 2.4375q0.40625 0.265625 0.984375 0.265625q0.578125 0 0.875 -0.53125q0.296875 -0.53125 0.296875 -1.3125zm4.0059357 -2.65625l0 4.46875l0.609375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.34375 -0.03125q0.3125 -0.03125 0.3125 -0.34375l0 -2.953125q0 -0.1875 -0.0625 -0.25q-0.0625 -0.0625 -0.203125 -0.0625l-0.375 -0.015625l0 -0.890625l1.90625 0zm-1.34375 -0.921875q-0.21875 -0.21875 -0.21875 -0.5625q0 -0.34375095 0.21875 -0.56250095q0.21875 -0.234375 0.578125 -0.234375q0.375 0 0.59375 0.234375q0.234375 0.21875 0.234375 0.56250095q0 0.34375 -0.234375 0.5625q-0.21875 0.21875 -0.59375 0.21875q-0.359375 0 -0.578125 -0.21875z" fill-rule="nonzero"/><path fill="#ff9900" d="m126.44907 65.90535l0 0c0 -2.968769 2.4066696 -5.375435 5.375435 -5.375435l119.64284 0c1.4256439 0 2.7929077 0.5663376 3.8009949 1.5744286c1.0080872 1.008091 1.5744324 2.3753548 1.5744324 3.8010063l0 21.501099c0 2.968773 -2.406662 5.375435 -5.3754272 5.375435l-119.64284 0c-2.9687653 0 -5.375435 -2.406662 -5.375435 -5.375435z" fill-rule="evenodd"/><path fill="#000000" d="m148.8283 77.207146q0 1.546875 -0.65625 2.546875q-0.640625 0.984375 -1.875 1.3125l0 0.03125q0.921875 0.046875 1.765625 0.5625q0.59375 0.375 1.390625 0.375q0.1875 0 0.375 -0.03125l-0.140625 1.15625q-0.171875 0 -0.328125 0l-1.1875 -0.203125q-0.75 -0.25 -1.46875 -0.59375q-0.8125 -0.390625 -1.515625 -0.390625q-0.34375 0 -0.671875 0.078125l0.171875 -0.859375q-2.859375 -0.375 -2.859375 -3.984375q0 -1.109375 0.3125 -1.9375q0.328125 -0.84375 0.875 -1.296875q1.0625 -0.921875 2.484375 -0.921875q1.421875 0 2.375 1.015625q0.953125 1.0 0.953125 3.140625zm-5.5 -0.0625q0 2.96875 2.0625 2.96875q1.109375 0 1.5625 -0.90625q0.390625 -0.796875 0.390625 -1.96875q0 -1.046875 -0.375 -1.921875q-0.1875 -0.484375 -0.609375 -0.75q-0.40625 -0.28125 -1.0625 -0.28125q-0.65625 0 -1.140625 0.4375q-0.46875 0.421875 -0.65625 1.046875q-0.171875 0.609375 -0.171875 1.375zm12.749313 3.078125l0 0.953125l-1.9375 0l0 -0.78125q-0.65625 0.9375 -1.9375 0.9375q-1.859375 0 -1.859375 -2.078125l0 -2.765625q0 -0.328125 -0.3125 -0.34375l-0.375 -0.015625l0 -0.984375l2.109375 0l0 3.796875q0 0.640625 0.171875 0.96875q0.1875 0.3125 0.75 0.3125q0.5625 0 0.90625 -0.34375q0.34375 -0.34375 0.34375 -0.84375l0 -2.53125q0 -0.1875 -0.078125 -0.265625q-0.0625 -0.078125 -0.21875 -0.09375l-0.390625 -0.015625l0 -0.984375l2.109375 0l0 4.703125q0 0.1875 0.0625 0.25q0.0625 0.0625 0.25 0.09375l0.40625 0.03125zm3.7483063 -5.265625q0.90625 0 1.421875 0.4375q0.53125 0.4375 0.53125 1.234375q0 0.53125 -0.234375 0.9375q-0.21875 0.390625 -0.5625 0.625q-0.328125 0.21875 -0.8125 0.375q-0.796875 0.234375 -1.8125 0.234375q0.046875 0.640625 0.40625 1.03125q0.359375 0.390625 1.09375 0.390625q0.75 0 1.5 -0.53125l0.453125 0.96875q-0.25 0.21875 -0.8125 0.453125q-0.5625 0.21875 -1.296875 0.21875q-1.46875 0 -2.15625 -0.8125q-0.6875 -0.8125 -0.6875 -2.21875q0 -1.421875 0.78125 -2.375q0.78125 -0.96875 2.1875 -0.96875zm-0.546875 2.75q0.4375 -0.078125 0.8125 -0.34375q0.375 -0.28125 0.375 -0.65625q0 -0.734375 -0.71875 -0.734375q-0.671875 0 -1.03125 0.546875q-0.359375 0.53125 -0.40625 1.265625q0.53125 0 0.96875 -0.078125zm6.4955444 -1.65625q-0.296875 -0.109375 -0.703125 -0.109375q-0.40625 0 -0.65625 0.1875q-0.25 0.1875 -0.25 0.46875q0 0.265625 0.09375 
0.421875q0.09375 0.140625 0.28125 0.25q0.296875 0.15625 0.71875 0.265625q0.421875 0.109375 0.625 0.1875q0.203125 0.0625 0.5 0.21875q0.3125 0.15625 0.46875 0.328125q0.421875 0.453125 0.421875 1.140625q0 0.90625 -0.65625 1.421875q-0.640625 0.5 -1.65625 0.5q-1.46875 0 -2.203125 -0.375l0 -1.671875l1.078125 -0.078125l0 0.578125q0 0.53125 1.0 0.53125q1.015625 0 1.015625 -0.734375q0 -0.265625 -0.171875 -0.4375q-0.171875 -0.171875 -0.34375 -0.21875q-0.171875 -0.0625 -0.40625 -0.125q-0.234375 -0.0625 -0.453125 -0.125q-0.21875 -0.0625 -0.484375 -0.15625q-0.25 -0.109375 -0.5625 -0.296875q-0.609375 -0.390625 -0.609375 -1.3125q0 -0.9375 0.65625 -1.4375q0.65625 -0.515625 1.65625 -0.515625q1.0 0 1.984375 0.484375l0 1.4375l-1.078125 0.078125l0 -0.5q0 -0.296875 -0.265625 -0.40625zm3.479538 -2.1875l0.78125 0l0 1.28125l1.59375 0l-0.125 1.0l-1.46875 0l0 3.15625q0 0.53125 0.1875 0.75q0.1875 0.203125 0.59375 0.203125q0.40625 0 0.8125 -0.25l0.359375 0.921875q-0.609375 0.421875 -1.53125 0.421875q-0.53125 0 -0.90625 -0.140625q-0.375 -0.140625 -0.5625 -0.296875q-0.171875 -0.171875 -0.265625 -0.5q-0.09375 -0.328125 -0.109375 -0.515625q-0.015625 -0.1875 -0.015625 -0.546875l0 -3.203125l-0.84375 0l0.109375 -0.875q0.5625 -0.046875 0.84375 -0.390625q0.296875 -0.359375 0.546875 -1.015625zm5.3521576 1.28125l0 5.03125l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.40625 -0.046875q0.34375 -0.03125 0.34375 -0.375l0 -3.34375q0 -0.203125 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.078125l-0.421875 -0.03125l0 -0.984375l2.140625 0zm-1.515625 -1.03125q-0.25 -0.25 -0.25 -0.640625q0 -0.390625 0.25 -0.640625q0.265625 -0.265625 0.671875 -0.265625q0.40625 0 0.65625 0.265625q0.265625 0.25 0.265625 0.640625q0 0.390625 -0.265625 0.640625q-0.25 0.234375 -0.65625 0.234375q-0.40625 0 -0.671875 -0.234375zm5.6614075 6.1875q1.3125 0 1.3125 -2.109375q0 -1.078125 -0.296875 -1.625q-0.296875 -0.546875 -0.984375 -0.546875q-0.6875 0 -1.015625 0.53125q-0.328125 0.515625 -0.328125 1.4375q0 1.6875 0.625 2.125q0.28125 0.1875 0.6875 0.1875zm-2.765625 -2.15625q0 -0.875 0.265625 -1.53125q0.265625 -0.65625 0.71875 -1.0q0.84375 -0.65625 1.875 -0.65625q0.71875 0 1.21875 0.234375q0.5 0.234375 0.78125 0.546875q0.28125 0.296875 0.46875 0.890625q0.203125 0.578125 0.203125 1.359375q0 1.65625 -0.8125 2.515625q-0.796875 0.859375 -2.046875 0.859375q-1.25 0 -1.96875 -0.8125q-0.703125 -0.8125 -0.703125 -2.40625zm6.2735443 -2.046875l0 -0.953125l2.03125 0l0 0.78125q0.296875 -0.453125 0.8125 -0.703125q0.53125 -0.265625 1.125 -0.265625q0.90625 0 1.40625 0.53125q0.5 0.515625 0.5 1.5625l0 3.125l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.390625 -0.046875q0.1875 -0.015625 0.265625 -0.09375q0.09375 -0.078125 0.09375 -0.296875l0 -2.46875q0 -0.65625 -0.203125 -0.96875q-0.203125 -0.328125 -0.78125 -0.328125q-0.5625 0 -0.90625 0.359375q-0.328125 0.359375 -0.328125 0.84375l0 2.90625l0.671875 0.046875l0 0.953125l-2.84375 0l0 -0.90625l0.390625 -0.046875q0.171875 -0.015625 0.25 -0.09375q0.09375 -0.078125 0.09375 -0.296875l0 -3.328125q0 -0.359375 -0.296875 -0.375l-0.5 -0.03125zm7.684433 3.734375l1.9375 -5.59375l-0.765625 -0.078125l0 -0.96875l2.75 0l2.453125 6.921875l0.015625 0l0.6875 0.046875l0 1.015625l-3.109375 0l0 -0.96875l0.46875 -0.046875q0.203125 -0.015625 0.265625 -0.078125q0.078125 -0.078125 0.015625 -0.28125l-0.25 -0.734375l-2.8125 0l-0.359375 1.046875l0.71875 0.046875l0 1.015625l-2.765625 0l0 -0.96875l0.390625 -0.046875q0.25 -0.015625 0.359375 -0.328125zm3.0625 -5.296875l-1.109375 3.421875l2.21875 0l-1.078125 -3.421875l-0.03125 0zm4.3971863 1.5625l0 
-0.953125l2.03125 0l0 0.78125q0.296875 -0.453125 0.8125 -0.703125q0.53125 -0.265625 1.125 -0.265625q0.90625 0 1.40625 0.53125q0.5 0.515625 0.5 1.5625l0 3.125l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.390625 -0.046875q0.1875 -0.015625 0.265625 -0.09375q0.09375 -0.078125 0.09375 -0.296875l0 -2.46875q0 -0.65625 -0.203125 -0.96875q-0.203125 -0.328125 -0.78125 -0.328125q-0.5625 0 -0.90625 0.359375q-0.328125 0.359375 -0.328125 0.84375l0 2.90625l0.671875 0.046875l0 0.953125l-2.84375 0l0 -0.90625l0.390625 -0.046875q0.171875 -0.015625 0.25 -0.09375q0.09375 -0.078125 0.09375 -0.296875l0 -3.328125q0 -0.359375 -0.296875 -0.375l-0.5 -0.03125zm10.418808 -0.046875q-0.296875 -0.109375 -0.703125 -0.109375q-0.40625 0 -0.65625 0.1875q-0.25 0.1875 -0.25 0.46875q0 0.265625 0.09375 0.421875q0.09375 0.140625 0.28125 0.25q0.296875 0.15625 0.71875 0.265625q0.421875 0.109375 0.625 0.1875q0.203125 0.0625 0.5 0.21875q0.3125 0.15625 0.46875 0.328125q0.421875 0.453125 0.421875 1.140625q0 0.90625 -0.65625 1.421875q-0.640625 0.5 -1.65625 0.5q-1.46875 0 -2.203125 -0.375l0 -1.671875l1.078125 -0.078125l0 0.578125q0 0.53125 1.0 0.53125q1.015625 0 1.015625 -0.734375q0 -0.265625 -0.171875 -0.4375q-0.171875 -0.171875 -0.34375 -0.21875q-0.171875 -0.0625 -0.40625 -0.125q-0.234375 -0.0625 -0.453125 -0.125q-0.21875 -0.0625 -0.484375 -0.15625q-0.25 -0.109375 -0.5625 -0.296875q-0.609375 -0.390625 -0.609375 -1.3125q0 -0.9375 0.65625 -1.4375q0.65625 -0.515625 1.65625 -0.515625q1.0 0 1.984375 0.484375l0 1.4375l-1.078125 0.078125l0 -0.5q0 -0.296875 -0.265625 -0.40625zm6.240158 -0.671875l1.15625 0l1.25 4.46875l0.015625 0l0.828125 -3.71875l-0.609375 -0.03125l0 -0.953125l2.40625 0l0 0.921875l-0.390625 0.015625q-0.25 0.03125 -0.328125 0.296875l-1.140625 4.796875l-1.734375 0l-0.984375 -3.578125l-0.03125 0l-1.015625 3.578125l-1.765625 0l-1.21875 -4.78125q-0.0625 -0.171875 -0.125 -0.234375q-0.0625 -0.0625 -0.21875 -0.078125l-0.34375 -0.015625l0 -0.921875l2.75 0l0 0.953125l-0.671875 0.03125l0.828125 3.6875l0.015625 0l1.328125 -4.4375zm8.57132 -0.421875q0.90625 0 1.421875 0.4375q0.53125 0.4375 0.53125 1.234375q0 0.53125 -0.234375 0.9375q-0.21875 0.390625 -0.5625 0.625q-0.328125 0.21875 -0.8125 0.375q-0.796875 0.234375 -1.8125 0.234375q0.046875 0.640625 0.40625 1.03125q0.359375 0.390625 1.09375 0.390625q0.75 0 1.5 -0.53125l0.453125 0.96875q-0.25 0.21875 -0.8125 0.453125q-0.5625 0.21875 -1.296875 0.21875q-1.46875 0 -2.15625 -0.8125q-0.6875 -0.8125 -0.6875 -2.21875q0 -1.421875 0.78125 -2.375q0.78125 -0.96875 2.1875 -0.96875zm-0.546875 2.75q0.4375 -0.078125 0.8125 -0.34375q0.375 -0.28125 0.375 -0.65625q0 -0.734375 -0.71875 -0.734375q-0.671875 0 -1.03125 0.546875q-0.359375 0.53125 -0.40625 1.265625q0.53125 0 0.96875 -0.078125zm4.0892944 2.125l0 -3.328125q0 -0.1875 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.09375l-0.5 -0.03125l0 -0.96875l2.015625 0l0 0.890625q0.21875 -0.46875 0.671875 -0.765625q0.453125 -0.3125 1.046875 -0.3125q0.609375 0 1.140625 0.265625l0 1.796875l-1.109375 0.078125l0 -0.546875q0 -0.25 -0.125 -0.3125q-0.125 -0.046875 -0.328125 -0.046875q-0.46875 0 -0.78125 0.34375q-0.3125 0.328125 -0.3125 0.859375l0 2.78125l1.046875 0.046875l0 0.953125l-3.21875 0l0 -0.90625l0.40625 -0.046875q0.171875 -0.015625 0.25 -0.09375q0.09375 -0.078125 0.09375 -0.296875zm7.809662 -4.875q0.90625 0 1.421875 0.4375q0.53125 0.4375 0.53125 1.234375q0 0.53125 -0.234375 0.9375q-0.21875 0.390625 -0.5625 0.625q-0.328125 0.21875 -0.8125 0.375q-0.796875 0.234375 -1.8125 0.234375q0.046875 0.640625 0.40625 1.03125q0.359375 0.390625 1.09375 0.390625q0.75 0 
1.5 -0.53125l0.453125 0.96875q-0.25 0.21875 -0.8125 0.453125q-0.5625 0.21875 -1.296875 0.21875q-1.46875 0 -2.15625 -0.8125q-0.6875 -0.8125 -0.6875 -2.21875q0 -1.421875 0.78125 -2.375q0.78125 -0.96875 2.1875 -0.96875zm-0.546875 2.75q0.4375 -0.078125 0.8125 -0.34375q0.375 -0.28125 0.375 -0.65625q0 -0.734375 -0.71875 -0.734375q-0.671875 0 -1.03125 0.546875q-0.359375 0.53125 -0.40625 1.265625q0.53125 0 0.96875 -0.078125zm4.0892944 2.125l0 -3.328125q0 -0.1875 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.09375l-0.5 -0.03125l0 -0.96875l2.015625 0l0 0.890625q0.21875 -0.46875 0.671875 -0.765625q0.453125 -0.3125 1.046875 -0.3125q0.609375 0 1.140625 0.265625l0 1.796875l-1.109375 0.078125l0 -0.546875q0 -0.25 -0.125 -0.3125q-0.125 -0.046875 -0.328125 -0.046875q-0.46875 0 -0.78125 0.34375q-0.3125 0.328125 -0.3125 0.859375l0 2.78125l1.046875 0.046875l0 0.953125l-3.21875 0l0 -0.90625l0.40625 -0.046875q0.171875 -0.015625 0.25 -0.09375q0.09375 -0.078125 0.09375 -0.296875z" fill-rule="nonzero"/><path fill="#ff9900" d="m6.3882904 65.90535l0 0c0 -2.968769 2.4066648 -5.375435 5.375436 -5.375435l84.80818 0c1.4256592 0 2.792923 0.5663376 3.8010101 1.5744286c1.0080872 1.008091 1.5744247 2.3753548 1.5744247 3.8010063l0 21.501099c0 2.968773 -2.406662 5.375435 -5.375435 5.375435l-84.80818 0c-2.968771 0 -5.375436 -2.406662 -5.375436 -5.375435z" fill-rule="evenodd"/><path fill="#000000" d="m21.444216 81.175896l0 -0.96875l0.421875 -0.046875q0.3125 -0.015625 0.3125 -0.328125l0 -5.578125l-0.6875 -0.03125l0 -1.03125l2.171875 0l3.390625 5.515625l0.03125 0l0 -4.453125l-0.796875 -0.03125l0 -1.03125l2.921875 0l0 0.96875l-0.421875 0.046875q-0.296875 0.015625 -0.296875 0.375l0 6.59375l-1.40625 0l-3.46875 -5.59375l-0.015625 0l0 4.53125l0.78125 0.046875l0 1.015625l-2.9375 0zm14.147554 0l-5.687502 0l0 -0.96875l0.390625 -0.046875q0.34375 -0.03125 0.34375 -0.328125l0 -5.578125l-0.6875 -0.03125l0 -1.03125l2.906252 0l0 0.96875l-0.40625 0.046875q-0.34375 0.03125 -0.34375 0.375l0 5.359375l1.984375 0q0.171875 0 0.21875 -0.09375q0.0625 -0.09375 0.0625 -0.296875l0 -0.828125l1.21875 0.078125l0 2.375zm5.600296 -6.703125q-0.140625 -0.234375 -0.984375 -0.234375q-1.125 0 -1.703125 0.765625q-0.5625 0.75 -0.5625 2.171875q0 2.96875 2.21875 2.96875q0.03125 0 0.3125 0q0.28125 0 0.53125 -0.078125q0.25 -0.078125 0.296875 -0.15625q0.0625 -0.09375 0.0625 -0.265625l0 -0.96875l1.21875 0.09375l0 2.0625q-0.9375 0.5 -2.375 0.5q-1.84375 0 -2.8125 -1.03125q-0.953125 -1.03125 -0.953125 -3.078125q0 -1.125 0.3125 -1.953125q0.328125 -0.84375 0.890625 -1.3125q1.078125 -0.90625 2.5625 -0.90625q1.203125 0 2.25 0.5l0 2.03125l-1.21875 0.078125l0 -0.921875q0 -0.1875 -0.046875 -0.265625zm4.536545 -1.859375l0 7.5625l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.390625 -0.046875q0.34375 -0.03125 0.34375 -0.375l0 -5.921875q0 -0.265625 -0.296875 -0.28125l-0.421875 -0.03125l0 -1.0l2.15625 0zm6.8103943 2.765625l0 4.46875q0 0.1875 0.0625 0.265625q0.078125 0.0625 0.234375 0.078125l0.390625 0.03125l0 0.953125l-1.90625 0l0 -0.703125l-0.03125 0q-0.609375 0.859375 -1.65625 0.859375q-1.21875 0 -1.8125 -0.78125q-0.578125 -0.78125 -0.578125 -2.140625q0 -1.625 0.78125 -2.53125q0.796875 -0.921875 2.375 -0.921875q1.03125 0 2.140625 0.421875zm-1.390625 3.890625l0 -3.109375q-0.328125 -0.15625 -0.921875 -0.15625q-0.8125 0 -1.171875 0.65625q-0.359375 0.640625 -0.359375 1.703125q0 1.953125 1.25 1.953125q0.53125 0 0.859375 -0.3125q0.34375 -0.328125 0.34375 -0.734375zm5.9602966 -3.21875q-0.296875 -0.109375 -0.703125 -0.109375q-0.40625 0 -0.65625 0.1875q-0.25 0.1875 -0.25 0.46875q0 
0.265625 0.09375 0.421875q0.09375 0.140625 0.28125 0.25q0.296875 0.15625 0.71875 0.265625q0.421875 0.109375 0.625 0.1875q0.203125 0.0625 0.5 0.21875q0.3125 0.15625 0.46875 0.328125q0.421875 0.453125 0.421875 1.140625q0 0.90625 -0.65625 1.421875q-0.640625 0.5 -1.65625 0.5q-1.46875 0 -2.203125 -0.375l0 -1.671875l1.078125 -0.078125l0 0.578125q0 0.53125 1.0 0.53125q1.015625 0 1.015625 -0.734375q0 -0.265625 -0.171875 -0.4375q-0.171875 -0.171875 -0.34375 -0.21875q-0.171875 -0.0625 -0.40625 -0.125q-0.234375 -0.0625 -0.453125 -0.125q-0.21875 -0.0625 -0.484375 -0.15625q-0.25 -0.109375 -0.5625 -0.296875q-0.609375 -0.390625 -0.609375 -1.3125q0 -0.9375 0.65625 -1.4375q0.65625 -0.515625 1.65625 -0.515625q1.0 0 1.984375 0.484375l0 1.4375l-1.078125 0.078125l0 -0.5q0 -0.296875 -0.265625 -0.40625zm5.412033 0q-0.296875 -0.109375 -0.703125 -0.109375q-0.40625 0 -0.65625 0.1875q-0.25 0.1875 -0.25 0.46875q0 0.265625 0.09375 0.421875q0.09375 0.140625 0.28125 0.25q0.296875 0.15625 0.71875 0.265625q0.421875 0.109375 0.625 0.1875q0.203125 0.0625 0.5 0.21875q0.3125 0.15625 0.46875 0.328125q0.4218712 0.453125 0.4218712 1.140625q0 0.90625 -0.6562462 1.421875q-0.640625 0.5 -1.65625 0.5q-1.46875 0 -2.203125 -0.375l0 -1.671875l1.078125 -0.078125l0 0.578125q0 0.53125 1.0 0.53125q1.015625 0 1.015625 -0.734375q0 -0.265625 -0.171875 -0.4375q-0.171875 -0.171875 -0.34375 -0.21875q-0.171875 -0.0625 -0.40625 -0.125q-0.234375 -0.0625 -0.453125 -0.125q-0.21875 -0.0625 -0.484375 -0.15625q-0.25 -0.109375 -0.5625 -0.296875q-0.609375 -0.390625 -0.609375 -1.3125q0 -0.9375 0.65625 -1.4375q0.65625 -0.515625 1.65625 -0.515625q1.0 0 1.984375 0.484375l0 1.4375l-1.078125 0.078125l0 -0.5q0 -0.296875 -0.265625 -0.40625zm4.4276543 -0.90625l0 5.03125l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.40625 -0.046875q0.34375 -0.03125 0.34375 -0.375l0 -3.34375q0 -0.203125 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.078125l-0.421875 -0.03125l0 -0.984375l2.140625 0zm-1.515625 -1.03125q-0.25 -0.25 -0.25 -0.640625q0 -0.390625 0.25 -0.640625q0.265625 -0.265625 0.671875 -0.265625q0.40625 0 0.65625 0.265625q0.265625 0.25 0.265625 0.640625q0 0.390625 -0.265625 0.640625q-0.25 0.234375 -0.65625 0.234375q-0.40625 0 -0.671875 -0.234375zm3.6614075 5.71875l0 -3.6875l-1.03125 0l0 -1.0l1.03125 0l0 -0.453125q0 -1.1875 0.53125 -1.703125q0.515625 -0.484375 1.484375 -0.484375q0.96875 0 1.71875 0.421875l-0.34375 0.953125q-0.578125 -0.296875 -1.078125 -0.296875q-0.5 0 -0.703125 0.25q-0.1875 0.234375 -0.1875 0.75l0 0.5625l1.78125 0l0 1.0l-1.78125 0l0 4.015625l1.015625 0.0625l0 0.953125l-3.1875 0l0 -0.90625l0.390625 -0.046875q0.1875 -0.015625 0.265625 -0.09375q0.09375 -0.078125 0.09375 -0.296875zm5.846283 -4.6875l0 5.03125l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.40625 -0.046875q0.34375 -0.03125 0.34375 -0.375l0 -3.34375q0 -0.203125 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.078125l-0.421875 -0.03125l0 -0.984375l2.140625 0zm-1.515625 -1.03125q-0.25 -0.25 -0.25 -0.640625q0 -0.390625 0.25 -0.640625q0.265625 -0.265625 0.671875 -0.265625q0.40625 0 0.65625 0.265625q0.265625 0.25 0.265625 0.640625q0 0.390625 -0.265625 0.640625q-0.25 0.234375 -0.65625 0.234375q-0.40625 0 -0.671875 -0.234375zm5.9114075 0.84375q0.90625 0 1.421875 0.4375q0.53125 0.4375 0.53125 1.234375q0 0.53125 -0.234375 0.9375q-0.21875 0.390625 -0.5625 0.625q-0.328125 0.21875 -0.8125 0.375q-0.796875 0.234375 -1.8125 0.234375q0.046875 0.640625 0.40625 1.03125q0.359375 0.390625 1.09375 0.390625q0.75 0 1.5 -0.53125l0.453125 0.96875q-0.25 0.21875 -0.8125 0.453125q-0.5625 0.21875 -1.296875 
0.21875q-1.46875 0 -2.15625 -0.8125q-0.6875 -0.8125 -0.6875 -2.21875q0 -1.421875 0.78125 -2.375q0.78125 -0.96875 2.1875 -0.96875zm-0.546875 2.75q0.4375 -0.078125 0.8125 -0.34375q0.375 -0.28125 0.375 -0.65625q0 -0.734375 -0.71875 -0.734375q-0.671875 0 -1.03125 0.546875q-0.359375 0.53125 -0.40625 1.265625q0.53125 0 0.96875 -0.078125zm4.0892944 2.125l0 -3.328125q0 -0.1875 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.09375l-0.5 -0.03125l0 -0.96875l2.015625 0l0 0.890625q0.21875 -0.46875 0.671875 -0.765625q0.453125 -0.3125 1.046875 -0.3125q0.609375 0 1.140625 0.265625l0 1.796875l-1.109375 0.078125l0 -0.546875q0 -0.25 -0.125 -0.3125q-0.125 -0.046875 -0.328125 -0.046875q-0.46875 0 -0.78125 0.34375q-0.3125 0.328125 -0.3125 0.859375l0 2.78125l1.046875 0.046875l0 0.953125l-3.21875 0l0 -0.90625l0.40625 -0.046875q0.171875 -0.015625 0.25 -0.09375q0.09375 -0.078125 0.09375 -0.296875z" fill-rule="nonzero"/><path fill="#ff9900" d="m-1.4698163E-5 123.44314l0 0c0 -2.968773 2.4066646 -5.375435 5.375436 -5.375435l97.186134 0c1.4256592 0 2.792923 0.5663376 3.8010101 1.5744247c1.0080872 1.0080948 1.5744324 2.3753586 1.5744324 3.8010101l0 21.501091c0 2.9687805 -2.4066696 5.3754425 -5.3754425 5.3754425l-97.186134 0c-2.9687712 0 -5.375436 -2.406662 -5.375436 -5.3754425z" fill-rule="evenodd"/><path fill="#000000" d="m14.727401 130.83994l2.999999 0q2.296875 0 2.296875 1.78125q0 0.984375 -0.984375 1.5q0.640625 0.15625 0.984375 0.59375q0.359375 0.421875 0.359375 1.09375q0 1.03125 -0.703125 1.578125q-0.6875 0.546875 -2.015625 0.546875l-2.999999 0l0 -0.859375l0.359375 -0.03125q0.296875 -0.03125 0.296875 -0.296875l0 -4.96875l-0.59375 -0.03125l0 -0.90625zm1.906249 3.859375l0 2.234375l1.0 0q1.375 0 1.375 -1.125q0 -0.546875 -0.34375 -0.828125q-0.34375 -0.28125 -0.984375 -0.28125l-1.046875 0zm0 -2.875l0 1.9375l0.78125 0q0.65625 0 0.96875 -0.265625q0.328125 -0.28125 0.328125 -0.765625q0 -0.5 -0.328125 -0.703125q-0.3125 -0.203125 -0.890625 -0.203125l-0.859375 0zm7.1182556 0.59375q0.796875 0 1.265625 0.390625q0.46875 0.390625 0.46875 1.09375q0 0.46875 -0.203125 0.828125q-0.203125 0.34375 -0.5 0.546875q-0.296875 0.203125 -0.734375 0.328125q-0.703125 0.21875 -1.609375 0.21875q0.046875 0.5625 0.359375 0.90625q0.3125 0.34375 0.96875 0.34375q0.671875 0 1.328125 -0.46875l0.40625 0.875q-0.203125 0.1875 -0.71875 0.390625q-0.5 0.203125 -1.15625 0.203125q-1.296875 0 -1.90625 -0.71875q-0.609375 -0.71875 -0.609375 -1.96875q0 -1.265625 0.6875 -2.109375q0.703125 -0.859375 1.953125 -0.859375zm-0.484375 2.4375q0.390625 -0.078125 0.71875 -0.3125q0.328125 -0.25 0.328125 -0.578125q0 -0.640625 -0.640625 -0.640625q-0.59375 0 -0.921875 0.46875q-0.3125 0.46875 -0.34375 1.140625q0.46875 -0.015625 0.859375 -0.078125zm3.6264343 1.890625l0 -2.953125q0 -0.171875 -0.0625 -0.234375q-0.0625 -0.078125 -0.203125 -0.09375l-0.453125 -0.03125l0 -0.859375l1.796875 0l0 0.796875q0.203125 -0.40625 0.59375 -0.671875q0.40625 -0.28125 0.9375 -0.28125q0.53125 0 1.015625 0.234375l0 1.578125l-1.0 0.078125l0 -0.484375q0 -0.21875 -0.109375 -0.265625q-0.109375 -0.0625 -0.28125 -0.0625q-0.421875 0 -0.703125 0.3125q-0.265625 0.296875 -0.265625 0.765625l0 2.46875l0.921875 0.046875l0 0.84375l-2.84375 0l0 -0.8125l0.34375 -0.03125q0.15625 -0.015625 0.234375 -0.078125q0.078125 -0.078125 0.078125 -0.265625zm5.43886 -5.296875l0.703125 0l0 1.125l1.40625 0l-0.109375 0.890625l-1.296875 0l0 2.8125q0 0.46875 0.15625 0.65625q0.171875 0.171875 0.53125 0.171875q0.359375 0 0.71875 -0.21875l0.328125 0.828125q-0.53125 0.375 -1.359375 0.375q-0.484375 0 -0.8125 -0.125q-0.328125 -0.125 -0.484375 
-0.265625q-0.1562519 -0.15625 -0.2500019 -0.4375q-0.078125 -0.296875 -0.09375 -0.453125q-0.015625 -0.171875 -0.015625 -0.5l0 -2.84375l-0.75 0l0.109375 -0.78125q0.5 -0.03125 0.75 -0.34375q0.2656269 -0.328125 0.4687519 -0.890625zm2.8299408 6.484375l0 -0.859375l0.375 -0.03125q0.28125 -0.03125 0.28125 -0.296875l0 -4.96875l-0.609375 -0.03125l0 -0.90625l1.921875 0l3.015625 4.90625l0.015625 0l0 -3.96875l-0.703125 -0.03125l0 -0.90625l2.609375 0l0 0.875l-0.390625 0.03125q-0.25 0.015625 -0.25 0.328125l0 5.859375l-1.25 0l-3.078125 -4.96875l-0.015625 0l0 4.03125l0.6875 0.03125l0 0.90625l-2.609375 0zm12.575226 0l-5.0625 0l0 -0.859375l0.34375 -0.03125q0.3125 -0.03125 0.3125 -0.296875l0 -4.96875l-0.609375 -0.03125l0 -0.90625l2.578125 0l0 0.875l-0.359375 0.03125q-0.3125 0.03125 -0.3125 0.34375l0 4.75l1.78125 0q0.140625 0 0.1875 -0.078125q0.046875 -0.078125 0.046875 -0.265625l0 -0.75l1.09375 0.078125l0 2.109375zm4.956833 -5.953125q-0.109375 -0.203125 -0.875 -0.203125q-0.984375 0 -1.5 0.671875q-0.5 0.671875 -0.5 1.9375q0 2.625 1.96875 2.625q0.03125 0 0.28125 0q0.25 0 0.46875 -0.0625q0.21875 -0.078125 0.265625 -0.140625q0.046875 -0.078125 0.046875 -0.234375l0 -0.859375l1.078125 0.078125l0 1.84375q-0.828125 0.4375 -2.109375 0.4375q-1.625 0 -2.484375 -0.90625q-0.859375 -0.921875 -0.859375 -2.734375q0 -1.0 0.28125 -1.75q0.296875 -0.75 0.78125 -1.171875q0.96875 -0.796875 2.28125 -0.796875q1.078125 0 2.015625 0.4375l0 1.8125l-1.09375 0.078125l0 -0.828125q0 -0.15625 -0.046875 -0.234375zm4.0401306 -1.640625l0 6.703125l0.59375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.359375 -0.03125q0.296875 -0.03125 0.296875 -0.34375l0 -5.25q0 -0.234375 -0.265625 -0.25l-0.359375 -0.03125l0 -0.875l1.90625 0zm6.050598 2.453125l0 3.953125q0 0.171875 0.046875 0.234375q0.0625 0.0625 0.21875 0.078125l0.34375 0.015625l0 0.859375l-1.703125 0l0 -0.625l-0.03125 0q-0.53125 0.765625 -1.453125 0.765625q-1.09375 0 -1.625 -0.6875q-0.515625 -0.703125 -0.515625 -1.90625q0 -1.453125 0.703125 -2.25q0.703125 -0.8125 2.109375 -0.8125q0.90625 0 1.90625 0.375zm-1.234375 3.453125l0 -2.765625q-0.296875 -0.140625 -0.828125 -0.140625q-0.71875 0 -1.046875 0.578125q-0.3125 0.578125 -0.3125 1.515625q0 1.734375 1.109375 1.734375q0.46875 0 0.765625 -0.28125q0.3125 -0.28125 0.3125 -0.640625zm5.2765274 -2.859375q-0.25 -0.09375 -0.625 -0.09375q-0.359375 0 -0.578125 0.171875q-0.21874237 0.15625 -0.21874237 0.390625q0 0.234375 0.078125 0.375q0.09374237 0.140625 0.26561737 0.234375q0.265625 0.140625 0.625 0.234375q0.375 0.09375 0.5625 0.15625q0.1875 0.0625 0.453125 0.203125q0.265625 0.140625 0.40625 0.296875q0.375 0.390625 0.375 1.015625q0 0.796875 -0.578125 1.25q-0.578125 0.453125 -1.484375 0.453125q-1.2968674 0 -1.9531174 -0.328125l0 -1.484375l0.953125 -0.078125l0 0.515625q0 0.46875 0.8906174 0.46875q0.90625 0 0.90625 -0.65625q0 -0.234375 -0.15625 -0.375q-0.15625 -0.15625 -0.3125 -0.203125q-0.140625 -0.0625 -0.359375 -0.109375q-0.203125 -0.046875 -0.40625 -0.09375q-0.18749237 -0.0625 -0.42186737 -0.15625q-0.21875 -0.09375 -0.5 -0.265625q-0.546875 -0.34375 -0.546875 -1.171875q0 -0.828125 0.578125 -1.265625q0.59375 -0.453125 1.4843674 -0.453125q0.890625 0 1.765625 0.421875l0 1.28125l-0.953125 0.078125l0 -0.453125q0 -0.265625 -0.25 -0.359375zm4.806015 0q-0.25 -0.09375 -0.625 -0.09375q-0.359375 0 -0.578125 0.171875q-0.21875 0.15625 -0.21875 0.390625q0 0.234375 0.078125 0.375q0.09375 0.140625 0.265625 0.234375q0.265625 0.140625 0.625 0.234375q0.375 0.09375 0.5625 0.15625q0.1875 0.0625 0.453125 0.203125q0.265625 0.140625 0.40625 0.296875q0.375 0.390625 0.375 1.015625q0 
0.796875 -0.578125 1.25q-0.578125 0.453125 -1.484375 0.453125q-1.296875 0 -1.953125 -0.328125l0 -1.484375l0.953125 -0.078125l0 0.515625q0 0.46875 0.890625 0.46875q0.90625 0 0.90625 -0.65625q0 -0.234375 -0.15625 -0.375q-0.15625 -0.15625 -0.3125 -0.203125q-0.140625 -0.0625 -0.359375 -0.109375q-0.203125 -0.046875 -0.40625 -0.09375q-0.1875 -0.0625 -0.421875 -0.15625q-0.21875 -0.09375 -0.5 -0.265625q-0.546875 -0.34375 -0.546875 -1.171875q0 -0.828125 0.578125 -1.265625q0.59375 -0.453125 1.484375 -0.453125q0.890625 0 1.765625 0.421875l0 1.28125l-0.953125 0.078125l0 -0.453125q0 -0.265625 -0.25 -0.359375zm3.94664 -0.8125l0 4.46875l0.609375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.34375 -0.03125q0.3125 -0.03125 0.3125 -0.34375l0 -2.953125q0 -0.1875 -0.0625 -0.25q-0.0625 -0.0625 -0.203125 -0.0625l-0.375 -0.015625l0 -0.890625l1.90625 0zm-1.34375 -0.921875q-0.21875 -0.21875 -0.21875 -0.5625q0 -0.34375 0.21875 -0.5625q0.21875 -0.234375 0.578125 -0.234375q0.375 0 0.59375 0.234375q0.234375 0.21875 0.234375 0.5625q0 0.34375 -0.234375 0.5625q-0.21875 0.21875 -0.59375 0.21875q-0.359375 0 -0.578125 -0.21875zm3.2324066 5.09375l0 -3.28125l-0.90625 0l0 -0.890625l0.90625 0l0 -0.40625q0 -1.046875 0.484375 -1.5q0.453125 -0.4375 1.3125 -0.4375q0.859375 0 1.53125 0.375l-0.3125 0.84375q-0.5 -0.265625 -0.953125 -0.265625q-0.453125 0 -0.625 0.21875q-0.15625 0.21875 -0.15625 0.671875l0 0.5l1.578125 0l0 0.890625l-1.578125 0l0 3.578125l0.890625 0.03125l0 0.859375l-2.828125 0l0 -0.8125l0.359375 -0.03125q0.15625 -0.015625 0.21875 -0.078125q0.078125 -0.078125 0.078125 -0.265625zm5.208481 -4.171875l0 4.46875l0.609375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.34375 -0.03125q0.3125 -0.03125 0.3125 -0.34375l0 -2.953125q0 -0.1875 -0.0625 -0.25q-0.0625 -0.0625 -0.203125 -0.0625l-0.375 -0.015625l0 -0.890625l1.90625 0zm-1.34375 -0.921875q-0.21875 -0.21875 -0.21875 -0.5625q0 -0.34375 0.21875 -0.5625q0.21875 -0.234375 0.578125 -0.234375q0.375 0 0.59375 0.234375q0.234375 0.21875 0.234375 0.5625q0 0.34375 -0.234375 0.5625q-0.21875 0.21875 -0.59375 0.21875q-0.359375 0 -0.578125 -0.21875zm5.2480316 0.765625q0.796875 0 1.265625 0.390625q0.46875 0.390625 0.46875 1.09375q0 0.46875 -0.203125 0.828125q-0.203125 0.34375 -0.5 0.546875q-0.296875 0.203125 -0.734375 0.328125q-0.703125 0.21875 -1.609375 0.21875q0.046875 0.5625 0.359375 0.90625q0.3125 0.34375 0.96875 0.34375q0.671875 0 1.328125 -0.46875l0.40625 0.875q-0.203125 0.1875 -0.71875 0.390625q-0.5 0.203125 -1.15625 0.203125q-1.296875 0 -1.90625 -0.71875q-0.609375 -0.71875 -0.609375 -1.96875q0 -1.265625 0.6875 -2.109375q0.703125 -0.859375 1.953125 -0.859375zm-0.484375 2.4375q0.390625 -0.078125 0.71875 -0.3125q0.328125 -0.25 0.328125 -0.578125q0 -0.640625 -0.640625 -0.640625q-0.59375 0 -0.921875 0.46875q-0.3125 0.46875 -0.34375 1.140625q0.46875 -0.015625 0.859375 -0.078125zm3.6264343 1.890625l0 -2.953125q0 -0.171875 -0.0625 -0.234375q-0.0625 -0.078125 -0.203125 -0.09375l-0.453125 -0.03125l0 -0.859375l1.796875 0l0 0.796875q0.203125 -0.40625 0.59375 -0.671875q0.40625 -0.28125 0.9375 -0.28125q0.53125 0 1.015625 0.234375l0 1.578125l-1.0 0.078125l0 -0.484375q0 -0.21875 -0.109375 -0.265625q-0.109375 -0.0625 -0.28125 -0.0625q-0.421875 0 -0.703125 0.3125q-0.265625 0.296875 -0.265625 0.765625l0 2.46875l0.921875 0.046875l0 0.84375l-2.84375 0l0 -0.8125l0.34375 -0.03125q0.15625 -0.015625 0.234375 -0.078125q0.078125 -0.078125 0.078125 -0.265625z" fill-rule="nonzero"/><path fill="#ff9900" d="m357.90506 65.907974l0 0c0 -2.968769 2.4066467 -5.375435 5.3754272 -5.375435l116.90268 0c1.4256592 0 2.7929077 
0.5663376 3.8009949 1.5744286c1.0080872 1.008091 1.5744324 2.3753548 1.5744324 3.8010063l0 21.501099c0 2.968773 -2.4066467 5.375435 -5.3754272 5.375435l-116.90268 0c-2.9687805 0 -5.3754272 -2.406662 -5.3754272 -5.375435z" fill-rule="evenodd"/><path fill="#000000" d="m371.1282 73.194145l3.390625 0q2.578125 0 2.578125 2.015625q0 1.09375 -1.09375 1.671875q0.703125 0.1875 1.109375 0.671875q0.40625 0.46875 0.40625 1.234375q0 1.171875 -0.796875 1.78125q-0.78125 0.609375 -2.28125 0.609375l-3.375 0l0 -0.96875l0.390625 -0.046875q0.359375 -0.03125 0.359375 -0.328125l0 -5.578125l-0.6875 -0.03125l0 -1.03125zm2.140625 4.34375l0 2.515625l1.140625 0q1.546875 0 1.546875 -1.28125q0 -0.59375 -0.390625 -0.90625q-0.375 -0.328125 -1.109375 -0.328125l-1.1875 0zm0 -3.234375l0 2.171875l0.890625 0q0.75 0 1.09375 -0.296875q0.359375 -0.3125 0.359375 -0.859375q0 -0.5625 -0.359375 -0.78125q-0.34375 -0.234375 -1.0 -0.234375l-0.984375 0zm10.395935 1.078125l0 4.46875q0 0.1875 0.0625 0.265625q0.078125 0.0625 0.234375 0.078125l0.390625 0.03125l0 0.953125l-1.90625 0l0 -0.703125l-0.03125 0q-0.609375 0.859375 -1.65625 0.859375q-1.21875 0 -1.8125 -0.78125q-0.578125 -0.78125 -0.578125 -2.140625q0 -1.625 0.78125 -2.53125q0.796875 -0.921875 2.375 -0.921875q1.03125 0 2.140625 0.421875zm-1.390625 3.890625l0 -3.109375q-0.328125 -0.15625 -0.921875 -0.15625q-0.8125 0 -1.171875 0.65625q-0.359375 0.640625 -0.359375 1.703125q0 1.953125 1.25 1.953125q0.53125 0 0.859375 -0.3125q0.34375 -0.328125 0.34375 -0.734375zm5.9602966 -3.21875q-0.296875 -0.109375 -0.703125 -0.109375q-0.40625 0 -0.65625 0.1875q-0.25 0.1875 -0.25 0.46875q0 0.265625 0.09375 0.421875q0.09375 0.140625 0.28125 0.25q0.296875 0.15625 0.71875 0.265625q0.421875 0.109375 0.625 0.1875q0.203125 0.0625 0.5 0.21875q0.3125 0.15625 0.46875 0.328125q0.421875 0.453125 0.421875 1.140625q0 0.90625 -0.65625 1.421875q-0.640625 0.5 -1.65625 0.5q-1.46875 0 -2.203125 -0.375l0 -1.671875l1.078125 -0.078125l0 0.578125q0 0.53125 1.0 0.53125q1.015625 0 1.015625 -0.734375q0 -0.265625 -0.171875 -0.4375q-0.171875 -0.171875 -0.34375 -0.21875q-0.171875 -0.0625 -0.40625 -0.125q-0.234375 -0.0625 -0.453125 -0.125q-0.21875 -0.0625 -0.484375 -0.15625q-0.25 -0.109375 -0.5625 -0.296875q-0.609375 -0.390625 -0.609375 -1.3125q0 -0.9375 0.65625 -1.4375q0.65625 -0.515625 1.65625 -0.515625q1.0 0 1.984375 0.484375l0 1.4375l-1.078125 0.078125l0 -0.5q0 -0.296875 -0.265625 -0.40625zm5.271393 -1.09375q0.90625 0 1.421875 0.4375q0.53125 0.4375 0.53125 1.234375q0 0.53125 -0.234375 0.9375q-0.21875 0.390625 -0.5625 0.625q-0.328125 0.21875 -0.8125 0.375q-0.796875 0.234375 -1.8125 0.234375q0.046875 0.640625 0.40625 1.03125q0.359375 0.390625 1.09375 0.390625q0.75 0 1.5 -0.53125l0.453125 0.96875q-0.25 0.21875 -0.8125 0.453125q-0.5625 0.21875 -1.296875 0.21875q-1.46875 0 -2.15625 -0.8125q-0.6875 -0.8125 -0.6875 -2.21875q0 -1.421875 0.78125 -2.375q0.78125 -0.96875 2.1875 -0.96875zm-0.546875 2.75q0.4375 -0.078125 0.8125 -0.34375q0.375 -0.28125 0.375 -0.65625q0 -0.734375 -0.71875 -0.734375q-0.671875 0 -1.03125 0.546875q-0.359375 0.53125 -0.40625 1.265625q0.53125 0 0.96875 -0.078125zm5.3549194 -3.078125l1.609375 4.828125l0.0625 0l1.6875 -5.1875l-0.734375 -0.046875l0 -1.03125l2.75 0l0 0.96875l-0.390625 0.046875q-0.15625 0.015625 -0.234375 0.09375q-0.0625 0.078125 -0.140625 0.265625l-2.265625 6.609375l-1.5625 0l-2.4375 -6.921875l-0.6875 -0.03125l0 -1.03125l3.09375 0l0 0.96875l-0.390625 0.046875q-0.375 0.015625 -0.375 0.28125q0 0.0625 0.015625 0.140625zm8.004303 0.515625l0 5.03125l0.6875 0.046875l0 0.953125l-2.859375 0l0 
-0.90625l0.40625 -0.046875q0.34375 -0.03125 0.34375 -0.375l0 -3.34375q0 -0.203125 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.078125l-0.421875 -0.03125l0 -0.984375l2.140625 0zm-1.515625 -1.03125q-0.25 -0.25 -0.25 -0.640625q0 -0.390625 0.25 -0.640625q0.265625 -0.265625 0.671875 -0.265625q0.40625 0 0.65625 0.265625q0.265625 0.25 0.265625 0.640625q0 0.390625 -0.265625 0.640625q-0.25 0.234375 -0.65625 0.234375q-0.40625 0 -0.671875 -0.234375zm6.0520325 1.9375q-0.296875 -0.109375 -0.703125 -0.109375q-0.40625 0 -0.65625 0.1875q-0.25 0.1875 -0.25 0.46875q0 0.265625 0.09375 0.421875q0.09375 0.140625 0.28125 0.25q0.296875 0.15625 0.71875 0.265625q0.421875 0.109375 0.625 0.1875q0.203125 0.0625 0.5 0.21875q0.3125 0.15625 0.46875 0.328125q0.421875 0.453125 0.421875 1.140625q0 0.90625 -0.65625 1.421875q-0.640625 0.5 -1.65625 0.5q-1.46875 0 -2.203125 -0.375l0 -1.671875l1.078125 -0.078125l0 0.578125q0 0.53125 1.0 0.53125q1.015625 0 1.015625 -0.734375q0 -0.265625 -0.171875 -0.4375q-0.171875 -0.171875 -0.34375 -0.21875q-0.171875 -0.0625 -0.40625 -0.125q-0.234375 -0.0625 -0.453125 -0.125q-0.21875 -0.0625 -0.484375 -0.15625q-0.25 -0.109375 -0.5625 -0.296875q-0.609375 -0.390625 -0.609375 -1.3125q0 -0.9375 0.65625 -1.4375q0.65625 -0.515625 1.65625 -0.515625q1.0 0 1.984375 0.484375l0 1.4375l-1.078125 0.078125l0 -0.5q0 -0.296875 -0.265625 -0.40625zm4.4276733 -0.90625l0 5.03125l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.40625 -0.046875q0.34375 -0.03125 0.34375 -0.375l0 -3.34375q0 -0.203125 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.078125l-0.421875 -0.03125l0 -0.984375l2.140625 0zm-1.515625 -1.03125q-0.25 -0.25 -0.25 -0.640625q0 -0.390625 0.25 -0.640625q0.265625 -0.265625 0.671875 -0.265625q0.40625 0 0.65625 0.265625q0.265625 0.25 0.265625 0.640625q0 0.390625 -0.265625 0.640625q-0.25 0.234375 -0.65625 0.234375q-0.40625 0 -0.671875 -0.234375zm5.6614075 6.1875q1.3125 0 1.3125 -2.109375q0 -1.078125 -0.296875 -1.625q-0.296875 -0.546875 -0.984375 -0.546875q-0.6875 0 -1.015625 0.53125q-0.328125 0.515625 -0.328125 1.4375q0 1.6875 0.625 2.125q0.28125 0.1875 0.6875 0.1875zm-2.765625 -2.15625q0 -0.875 0.265625 -1.53125q0.265625 -0.65625 0.71875 -1.0q0.84375 -0.65625 1.875 -0.65625q0.71875 0 1.21875 0.234375q0.5 0.234375 0.78125 0.546875q0.28125 0.296875 0.46875 0.890625q0.203125 0.578125 0.203125 1.359375q0 1.65625 -0.8125 2.515625q-0.796875 0.859375 -2.046875 0.859375q-1.25 0 -1.96875 -0.8125q-0.703125 -0.8125 -0.703125 -2.40625zm6.273529 -2.046875l0 -0.953125l2.03125 0l0 0.78125q0.296875 -0.453125 0.8125 -0.703125q0.53125 -0.265625 1.125 -0.265625q0.90625 0 1.40625 0.53125q0.5 0.515625 0.5 1.5625l0 3.125l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.390625 -0.046875q0.1875 -0.015625 0.265625 -0.09375q0.09375 -0.078125 0.09375 -0.296875l0 -2.46875q0 -0.65625 -0.203125 -0.96875q-0.203125 -0.328125 -0.78125 -0.328125q-0.5625 0 -0.90625 0.359375q-0.328125 0.359375 -0.328125 0.84375l0 2.90625l0.671875 0.046875l0 0.953125l-2.84375 0l0 -0.90625l0.390625 -0.046875q0.171875 -0.015625 0.25 -0.09375q0.09375 -0.078125 0.09375 -0.296875l0 -3.328125q0 -0.359375 -0.296875 -0.375l-0.5 -0.03125zm8.881317 3.734375l0 -5.453125l-0.921875 0q-0.28125 0 -0.28125 0.375l0 0.5625l-1.21875 -0.078125l0 -2.046875l6.265625 0l0 2.046875l-1.21875 0.078125l0 -0.5625q0 -0.203125 -0.0625 -0.28125q-0.0625 -0.09375 -0.28125 -0.09375l-0.828125 0l0 5.734375l0.953125 0.046875l0 1.015625l-3.390625 0l0 -0.96875l0.640625 -0.046875q0.34375 -0.03125 0.34375 -0.328125zm9.807922 -4.453125l0 4.46875q0 0.1875 0.0625 0.265625q0.078125 0.0625 
0.234375 0.078125l0.390625 0.03125l0 0.953125l-1.90625 0l0 -0.703125l-0.03125 0q-0.609375 0.859375 -1.65625 0.859375q-1.21875 0 -1.8125 -0.78125q-0.578125 -0.78125 -0.578125 -2.140625q0 -1.625 0.78125 -2.53125q0.796875 -0.921875 2.375 -0.921875q1.03125 0 2.140625 0.421875zm-1.390625 3.890625l0 -3.109375q-0.328125 -0.15625 -0.921875 -0.15625q-0.8125 0 -1.171875 0.65625q-0.359375 0.640625 -0.359375 1.703125q0 1.953125 1.25 1.953125q0.53125 0 0.859375 -0.3125q0.34375 -0.328125 0.34375 -0.734375zm5.9602966 -3.21875q-0.296875 -0.109375 -0.703125 -0.109375q-0.40625 0 -0.65625 0.1875q-0.25 0.1875 -0.25 0.46875q0 0.265625 0.09375 0.421875q0.09375 0.140625 0.28125 0.25q0.296875 0.15625 0.71875 0.265625q0.421875 0.109375 0.625 0.1875q0.203125 0.0625 0.5 0.21875q0.3125 0.15625 0.46875 0.328125q0.421875 0.453125 0.421875 1.140625q0 0.90625 -0.65625 1.421875q-0.640625 0.5 -1.65625 0.5q-1.46875 0 -2.203125 -0.375l0 -1.671875l1.078125 -0.078125l0 0.578125q0 0.53125 1.0 0.53125q1.015625 0 1.015625 -0.734375q0 -0.265625 -0.171875 -0.4375q-0.171875 -0.171875 -0.34375 -0.21875q-0.171875 -0.0625 -0.40625 -0.125q-0.234375 -0.0625 -0.453125 -0.125q-0.21875 -0.0625 -0.484375 -0.15625q-0.25 -0.109375 -0.5625 -0.296875q-0.609375 -0.390625 -0.609375 -1.3125q0 -0.9375 0.65625 -1.4375q0.65625 -0.515625 1.65625 -0.515625q1.0 0 1.984375 0.484375l0 1.4375l-1.078125 0.078125l0 -0.5q0 -0.296875 -0.265625 -0.40625zm3.0057678 3.78125l0 -5.84375q0 -0.1875 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.078125l-0.421875 -0.03125l0 -1.0l2.140625 0l0 5.203125q0.65625 -0.0625 1.140625 -0.40625q0.5 -0.34375 0.671875 -0.921875q0.046875 -0.109375 0.046875 -0.1875q0 -0.203125 -0.234375 -0.203125l-0.453125 -0.015625l0 -0.9375l2.484375 0l0 0.96875l-0.578125 0.015625q-0.15625 1.234375 -1.0625 1.9375l1.1875 2.109375l0.71875 0.046875l0 0.953125l-1.75 0l-1.359375 -2.5q-0.40625 0.109375 -0.8125 0.171875l0 1.328125l0 0l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.40625 -0.046875q0.171875 -0.015625 0.25 -0.09375q0.09375 -0.078125 0.09375 -0.296875zm6.343933 0l1.9375 -5.59375l-0.765625 -0.078125l0 -0.96875l2.75 0l2.453125 6.921875l0.015625 0l0.6875 0.046875l0 1.015625l-3.109375 0l0 -0.96875l0.46875 -0.046875q0.203125 -0.015625 0.265625 -0.078125q0.078125 -0.078125 0.015625 -0.28125l-0.25 -0.734375l-2.8125 0l-0.359375 1.046875l0.71875 0.046875l0 1.015625l-2.765625 0l0 -0.96875l0.390625 -0.046875q0.25 -0.015625 0.359375 -0.328125zm3.0625 -5.296875l-1.109375 3.421875l2.21875 0l-1.078125 -3.421875l-0.03125 0zm4.8915405 7.78125l0 -5.8125q0 -0.203125 -0.0625 -0.28125q-0.0625 -0.078125 -0.234375 -0.078125l-0.5 -0.03125l0 -0.96875l2.03125 0l0 0.78125q0.1875 -0.359375 0.6875 -0.65625q0.5 -0.3125 1.203125 -0.3125q2.171875 0 2.171875 3.046875q0 1.625 -0.71875 2.5q-0.703125 0.859375 -1.90625 0.859375q-0.71875 0 -1.28125 -0.359375l0 1.65625l0.96875 0.046875l0 0.953125l-3.09375 0l0 -0.90625l0.390625 -0.046875q0.1875 -0.015625 0.265625 -0.09375q0.078125 -0.078125 0.078125 -0.296875zm3.8125 -4.1875q0 -2.0 -1.21875 -2.0q-0.5 0 -0.859375 0.3125q-0.34375 0.296875 -0.34375 0.71875l0 2.75q0.453125 0.3125 1.109375 0.3125q0.65625 0 0.984375 -0.59375q0.328125 -0.609375 0.328125 -1.5zm4.5035706 -2.984375l0 5.03125l0.6875 0.046875l0 0.953125l-2.859375 0l0 -0.90625l0.40625 -0.046875q0.34375 -0.03125 0.34375 -0.375l0 -3.34375q0 -0.203125 -0.0625 -0.265625q-0.0625 -0.078125 -0.234375 -0.078125l-0.421875 -0.03125l0 -0.984375l2.140625 0zm-1.515625 -1.03125q-0.25 -0.25 -0.25 -0.640625q0 -0.390625 0.25 -0.640625q0.265625 -0.265625 0.671875 -0.265625q0.40625 0 0.65625 
0.265625q0.265625 0.25 0.265625 0.640625q0 0.390625 -0.265625 0.640625q-0.25 0.234375 -0.65625 0.234375q-0.40625 0 -0.671875 -0.234375z" fill-rule="nonzero"/><path fill="#ff9900" d="m473.916 123.446304l0 0c0 -2.968773 2.4066772 -5.375435 5.375458 -5.375435l74.760925 0c1.4256592 0 2.7929077 0.5663376 3.8010254 1.5744247c1.0080566 1.0080872 1.5744019 2.375351 1.5744019 3.8010101l0 21.501099c0 2.9687653 -2.4066772 5.3754272 -5.3754272 5.3754272l-74.760925 0l0 0c-2.9687805 0 -5.375458 -2.406662 -5.375458 -5.3754272z" fill-rule="evenodd"/><path fill="#000000" d="m488.88403 131.34435q1.109375 0 1.859375 0.78125q0.75 0.78125 0.75 2.390625q0 1.609375 -0.75 2.4375q-0.734375 0.828125 -2.0 0.828125q-1.265625 0 -1.984375 -0.8125q-0.71875 -0.828125 -0.71875 -2.40625q0 -0.859375 0.25 -1.5q0.25 -0.65625 0.65625 -1.0q0.84375 -0.71875 1.9375 -0.71875zm-1.671875 3.171875q0 2.3125 1.59375 2.3125q0.859375 0 1.203125 -0.703125q0.328125 -0.625 0.328125 -1.53125q0 -0.8125 -0.28125 -1.484375q-0.171875 -0.375 -0.5 -0.59375q-0.3125 -0.21875 -0.828125 -0.21875q-0.5 0 -0.875 0.34375q-0.375 0.328125 -0.515625 0.8125q-0.125 0.46875 -0.125 1.0625zm5.2257385 2.859375l0 -5.359375q0 -0.203125 -0.234375 -0.21875l-0.328125 -0.015625l0 -0.78125l1.640625 0l0 2.21875q0.421875 -0.390625 1.125 -0.390625q0.890625 0 1.390625 0.59375q0.515625 0.59375 0.515625 1.65625q0 2.703125 -2.25 2.703125q-0.484375 0 -1.03125 -0.125q-0.53125 -0.109375 -0.828125 -0.28125zm1.078125 -3.015625l0 2.40625q0.359375 0.1875 0.828125 0.1875q1.0625 0 1.0625 -1.71875q0 -1.578125 -1.046875 -1.578125q-0.34375 0 -0.59375 0.203125q-0.25 0.1875 -0.25 0.5zm5.327057 -1.390625l0 4.90625q0 0.75 -0.234375 1.21875q-0.3125 0.5625 -1.25 0.5625q-0.671875 0 -1.203125 -0.328125l0.34375 -0.78125q0.359375 0.203125 0.65625 0.203125q0.3125 0 0.4375 -0.1875q0.140625 -0.1875 0.140625 -0.640625l0 -3.890625q0 -0.15625 -0.046875 -0.203125q-0.046875 -0.0625 -0.1875 -0.078125l-0.328125 -0.015625l0 -0.765625l1.671875 0zm-1.1875 -0.8125q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.5q0.203125 -0.203125 0.515625 -0.203125q0.328125 0 0.53125 0.203125q0.203125 0.203125 0.203125 0.5q0 0.296875 -0.203125 0.484375q-0.203125 0.1875 -0.53125 0.1875q-0.3125 0 -0.515625 -0.1875zm4.546112 0.671875q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 1.0q0.421875 -0.015625 0.75 -0.078125zm3.7335815 0.3125q0 0.78125 0.296875 1.21875q0.296875 0.4375 0.859375 0.4375q0.5625 0 1.125 -0.390625l0.390625 0.71875q-0.65625 0.53125 -1.65625 0.53125q-1.0 0 -1.578125 -0.625q-0.5625 -0.625 -0.5625 -1.859375q0 -1.234375 0.65625 -1.84375q0.65625 -0.625 1.484375 -0.625q0.828125 0 1.53125 0.375l0 1.21875l-0.859375 0.0625l0 -0.453125q0 -0.25 -0.1875 -0.3125q-0.171875 -0.0625 -0.375 -0.0625q-1.125 0 -1.125 1.609375zm4.2998047 -3.296875l0.609375 0l0 1.0l1.234375 0l-0.09375 0.78125l-1.140625 0l0 2.453125q0 0.40625 0.140625 0.578125q0.15625 0.15625 0.46875 0.15625q0.3125 0 0.625 -0.203125l0.28125 0.71875q-0.46875 0.328125 
-1.1875 0.328125q-0.421875 0 -0.71875 -0.109375q-0.28125 -0.09375 -0.421875 -0.21875q-0.140625 -0.140625 -0.21875 -0.390625q-0.0625 -0.25 -0.078125 -0.390625q0 -0.15625 0 -0.4375l0 -2.484375l-0.671875 0l0.09375 -0.6875q0.4375 -0.03125 0.65625 -0.296875q0.234375 -0.28125 0.421875 -0.796875zm2.485443 5.6875l0 -0.75l0.3125 -0.03125q0.265625 -0.03125 0.265625 -0.265625l0 -4.328125l-0.53125 -0.03125l0 -0.796875l2.5625 0q1.375 0 2.140625 0.734375q0.765625 0.71875 0.765625 2.140625q0 0.890625 -0.234375 1.546875q-0.234375 0.65625 -0.625 1.03125q-0.8125 0.75 -2.0 0.75l-2.65625 0zm1.71875 -5.296875l0 4.40625l0.953125 0q0.796875 0 1.25 -0.578125q0.453125 -0.578125 0.453125 -1.6875q0 -2.140625 -1.828125 -2.140625l-0.828125 0zm6.828186 0.46875q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 1.0q0.421875 -0.015625 0.75 -0.078125zm3.6398315 -2.984375l0.609375 0l0 1.0l1.234375 0l-0.09375 0.78125l-1.140625 0l0 2.453125q0 0.40625 0.140625 0.578125q0.15625 0.15625 0.46875 0.15625q0.3125 0 0.625 -0.203125l0.28125 0.71875q-0.46875 0.328125 -1.1875 0.328125q-0.421875 0 -0.71875 -0.109375q-0.28125 -0.09375 -0.421875 -0.21875q-0.140625 -0.140625 -0.21875 -0.390625q-0.0625 -0.25 -0.078125 -0.390625q0 -0.15625 0 -0.4375l0 -2.484375l-0.671875 0l0.09375 -0.6875q0.4375 -0.03125 0.65625 -0.296875q0.234375 -0.28125 0.421875 -0.796875zm4.8292236 0.859375q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 1.0q0.421875 -0.015625 0.75 -0.078125zm3.7335815 0.3125q0 0.78125 0.296875 1.21875q0.296875 0.4375 0.859375 0.4375q0.5625 0 1.125 -0.390625l0.390625 0.71875q-0.65625 0.53125 -1.65625 0.53125q-1.0 0 -1.578125 -0.625q-0.5625 -0.625 -0.5625 -1.859375q0 -1.234375 0.65625 -1.84375q0.65625 -0.625 1.484375 -0.625q0.828125 0 1.53125 0.375l0 1.21875l-0.859375 0.0625l0 -0.453125q0 -0.25 -0.1875 -0.3125q-0.171875 -0.0625 -0.375 -0.0625q-1.125 0 -1.125 1.609375zm4.2998047 -3.296875l0.609375 0l0 1.0l1.234375 0l-0.09375 0.78125l-1.140625 0l0 2.453125q0 0.40625 0.140625 0.578125q0.15625 0.15625 0.46875 0.15625q0.3125 0 0.625 -0.203125l0.28125 0.71875q-0.46875 0.328125 -1.1875 0.328125q-0.421875 0 -0.71875 -0.109375q-0.28125 -0.09375 -0.421875 -0.21875q-0.140625 -0.140625 -0.21875 -0.390625q-0.0625 -0.25 -0.078125 -0.390625q0 -0.15625 0 -0.4375l0 -2.484375l-0.671875 0l0.09375 -0.6875q0.4375 -0.03125 0.65625 -0.296875q0.234375 -0.28125 0.421875 -0.796875zm4.439392 
5.0q1.015625 0 1.015625 -1.640625q0 -0.828125 -0.234375 -1.25q-0.21875 -0.4375 -0.765625 -0.4375q-0.53125 0 -0.78125 0.421875q-0.25 0.40625 -0.25 1.109375q0 1.3125 0.484375 1.65625q0.21875 0.140625 0.53125 0.140625zm-2.140625 -1.65625q0 -0.703125 0.203125 -1.203125q0.203125 -0.515625 0.546875 -0.78125q0.65625 -0.5 1.46875 -0.5q0.5625 0 0.9375 0.1875q0.390625 0.171875 0.609375 0.40625q0.21875 0.234375 0.375 0.6875q0.15625 0.453125 0.15625 1.078125q0 1.28125 -0.625 1.953125q-0.625 0.65625 -1.609375 0.65625q-0.96875 0 -1.515625 -0.625q-0.546875 -0.640625 -0.546875 -1.859375zm5.503357 1.296875l0 -2.578125q0 -0.15625 -0.0625 -0.21875q-0.046875 -0.0625 -0.171875 -0.0625l-0.390625 -0.03125l0 -0.75l1.578125 0l0 0.6875q0.15625 -0.359375 0.515625 -0.59375q0.359375 -0.234375 0.8125 -0.234375q0.46875 0 0.890625 0.203125l0 1.390625l-0.875 0.0625l0 -0.421875q0 -0.1875 -0.09375 -0.234375q-0.09375 -0.046875 -0.25 -0.046875q-0.375 0 -0.609375 0.265625q-0.234375 0.265625 -0.234375 0.671875l0 2.15625l0.8125 0.046875l0 0.734375l-2.5 0l0 -0.703125l0.296875 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.078125 -0.0625 0.078125 -0.234375z" fill-rule="nonzero"/><path fill="#ff9900" d="m122.40648 123.44314l0 0c0 -2.968773 2.4066696 -5.375435 5.375435 -5.375435l127.73733 0c1.4256439 0 2.792923 0.5663376 3.8010101 1.5744247c1.0080872 1.0080948 1.5744324 2.3753586 1.5744324 3.8010101l0 21.501091c0 2.9687805 -2.4066772 5.3754425 -5.3754425 5.3754425l-127.73733 0c-2.9687653 0 -5.375435 -2.406662 -5.375435 -5.3754425z" fill-rule="evenodd"/><path fill="#000000" d="m137.1603 130.83994l3.0 0q2.296875 0 2.296875 1.78125q0 0.984375 -0.984375 1.5q0.640625 0.15625 0.984375 0.59375q0.359375 0.421875 0.359375 1.09375q0 1.03125 -0.703125 1.578125q-0.6875 0.546875 -2.015625 0.546875l-3.0 0l0 -0.859375l0.359375 -0.03125q0.296875 -0.03125 0.296875 -0.296875l0 -4.96875l-0.59375 -0.03125l0 -0.90625zm1.90625 3.859375l0 2.234375l1.0 0q1.375 0 1.375 -1.125q0 -0.546875 -0.34375 -0.828125q-0.34375 -0.28125 -0.984375 -0.28125l-1.046875 0zm0 -2.875l0 1.9375l0.78125 0q0.65625 0 0.96875 -0.265625q0.328125 -0.28125 0.328125 -0.765625q0 -0.5 -0.328125 -0.703125q-0.3125 -0.203125 -0.890625 -0.203125l-0.859375 0zm7.1182556 0.59375q0.796875 0 1.265625 0.390625q0.46875 0.390625 0.46875 1.09375q0 0.46875 -0.203125 0.828125q-0.203125 0.34375 -0.5 0.546875q-0.296875 0.203125 -0.734375 0.328125q-0.703125 0.21875 -1.609375 0.21875q0.046875 0.5625 0.359375 0.90625q0.3125 0.34375 0.96875 0.34375q0.671875 0 1.328125 -0.46875l0.40625 0.875q-0.203125 0.1875 -0.71875 0.390625q-0.5 0.203125 -1.15625 0.203125q-1.296875 0 -1.90625 -0.71875q-0.609375 -0.71875 -0.609375 -1.96875q0 -1.265625 0.6875 -2.109375q0.703125 -0.859375 1.953125 -0.859375zm-0.484375 2.4375q0.390625 -0.078125 0.71875 -0.3125q0.328125 -0.25 0.328125 -0.578125q0 -0.640625 -0.640625 -0.640625q-0.59375 0 -0.921875 0.46875q-0.3125 0.46875 -0.34375 1.140625q0.46875 -0.015625 0.859375 -0.078125zm3.6264343 1.890625l0 -2.953125q0 -0.171875 -0.0625 -0.234375q-0.0625 -0.078125 -0.203125 -0.09375l-0.453125 -0.03125l0 -0.859375l1.796875 0l0 0.796875q0.203125 -0.40625 0.59375 -0.671875q0.40625 -0.28125 0.9375 -0.28125q0.53125 0 1.015625 0.234375l0 1.578125l-1.0 0.078125l0 -0.484375q0 -0.21875 -0.109375 -0.265625q-0.109375 -0.0625 -0.28125 -0.0625q-0.421875 0 -0.703125 0.3125q-0.265625 0.296875 -0.265625 0.765625l0 2.46875l0.921875 0.046875l0 0.84375l-2.84375 0l0 -0.8125l0.34375 -0.03125q0.15625 -0.015625 0.234375 -0.078125q0.078125 -0.078125 0.078125 -0.265625zm5.438858 -5.296875l0.703125 0l0 
1.125l1.40625 0l-0.109375 0.890625l-1.296875 0l0 2.8125q0 0.46875 0.15625 0.65625q0.171875 0.171875 0.53125 0.171875q0.359375 0 0.71875 -0.21875l0.328125 0.828125q-0.53125 0.375 -1.359375 0.375q-0.484375 0 -0.8125 -0.125q-0.328125 -0.125 -0.484375 -0.265625q-0.15625 -0.15625 -0.25 -0.4375q-0.078125 -0.296875 -0.09375 -0.453125q-0.015625 -0.171875 -0.015625 -0.5l0 -2.84375l-0.75 0l0.109375 -0.78125q0.5 -0.03125 0.75 -0.34375q0.265625 -0.328125 0.46875 -0.890625zm9.126816 2.953125q0 1.375 -0.578125 2.265625q-0.578125 0.875 -1.6875 1.171875l0 0.03125q0.828125 0.046875 1.578125 0.484375q0.53125 0.328125 1.234375 0.328125q0.171875 0 0.328125 -0.015625l-0.125 1.03125q-0.140625 0 -0.28125 0l-1.0625 -0.171875q-0.671875 -0.234375 -1.3125 -0.53125q-0.71875 -0.34375 -1.328125 -0.34375q-0.3125 0 -0.609375 0.0625l0.15625 -0.765625q-2.53125 -0.328125 -2.53125 -3.546875q0 -0.96875 0.28125 -1.703125q0.28125 -0.75 0.765625 -1.15625q0.9375 -0.828125 2.203125 -0.828125q1.265625 0 2.109375 0.90625q0.859375 0.890625 0.859375 2.78125zm-4.890625 -0.046875q0 2.640625 1.828125 2.640625q0.984375 0 1.390625 -0.8125q0.34375 -0.71875 0.34375 -1.75q0 -0.921875 -0.328125 -1.703125q-0.171875 -0.421875 -0.546875 -0.65625q-0.359375 -0.25 -0.953125 -0.25q-0.578125 0 -1.0 0.390625q-0.421875 0.375 -0.578125 0.921875q-0.15625 0.53125 -0.15625 1.21875zm11.325821 2.734375l0 0.84375l-1.734375 0l0 -0.6875q-0.578125 0.828125 -1.71875 0.828125q-1.65625 0 -1.65625 -1.84375l0 -2.46875q0 -0.28125 -0.265625 -0.296875l-0.34375 -0.015625l0 -0.875l1.875 0l0 3.375q0 0.5625 0.15625 0.859375q0.171875 0.28125 0.671875 0.28125q0.5 0 0.796875 -0.296875q0.3125 -0.3125 0.3125 -0.75l0 -2.25q0 -0.171875 -0.0625 -0.234375q-0.0625 -0.078125 -0.203125 -0.09375l-0.34375 -0.015625l0 -0.875l1.875 0l0 4.171875q0 0.171875 0.046875 0.234375q0.0625 0.0625 0.21875 0.078125l0.375 0.03125zm3.3291626 -4.671875q0.796875 0 1.265625 0.390625q0.46875 0.390625 0.46875 1.09375q0 0.46875 -0.203125 0.828125q-0.203125 0.34375 -0.5 0.546875q-0.296875 0.203125 -0.734375 0.328125q-0.703125 0.21875 -1.609375 0.21875q0.046875 0.5625 0.359375 0.90625q0.3125 0.34375 0.96875 0.34375q0.671875 0 1.328125 -0.46875l0.40625 0.875q-0.203125 0.1875 -0.71875 0.390625q-0.5 0.203125 -1.15625 0.203125q-1.296875 0 -1.90625 -0.71875q-0.609375 -0.71875 -0.609375 -1.96875q0 -1.265625 0.6875 -2.109375q0.703125 -0.859375 1.953125 -0.859375zm-0.484375 2.4375q0.390625 -0.078125 0.71875 -0.3125q0.328125 -0.25 0.328125 -0.578125q0 -0.640625 -0.640625 -0.640625q-0.59375 0 -0.921875 0.46875q-0.3125 0.46875 -0.34375 1.140625q0.46875 -0.015625 0.859375 -0.078125zm5.7514343 -1.46875q-0.25 -0.09375 -0.625 -0.09375q-0.359375 0 -0.578125 0.171875q-0.21875 0.15625 -0.21875 0.390625q0 0.234375 0.078125 0.375q0.09375 0.140625 0.265625 0.234375q0.265625 0.140625 0.625 0.234375q0.375 0.09375 0.5625 0.15625q0.1875 0.0625 0.453125 0.203125q0.265625 0.140625 0.40625 0.296875q0.375 0.390625 0.375 1.015625q0 0.796875 -0.578125 1.25q-0.578125 0.453125 -1.484375 0.453125q-1.296875 0 -1.953125 -0.328125l0 -1.484375l0.953125 -0.078125l0 0.515625q0 0.46875 0.890625 0.46875q0.90625 0 0.90625 -0.65625q0 -0.234375 -0.15625 -0.375q-0.15625 -0.15625 -0.3125 -0.203125q-0.140625 -0.0625 -0.359375 -0.109375q-0.203125 -0.046875 -0.40625 -0.09375q-0.1875 -0.0625 -0.421875 -0.15625q-0.21875 -0.09375 -0.5 -0.265625q-0.546875 -0.34375 -0.546875 -1.171875q0 -0.828125 0.578125 -1.265625q0.59375 -0.453125 1.484375 -0.453125q0.890625 0 1.765625 0.421875l0 1.28125l-0.953125 0.078125l0 -0.453125q0 -0.265625 -0.25 -0.359375zm3.0899658 
-1.9375l0.703125 0l0 1.125l1.40625 0l-0.109375 0.890625l-1.296875 0l0 2.8125q0 0.46875 0.15625 0.65625q0.171875 0.171875 0.53125 0.171875q0.359375 0 0.71875 -0.21875l0.328125 0.828125q-0.53125 0.375 -1.359375 0.375q-0.484375 0 -0.8125 -0.125q-0.328125 -0.125 -0.484375 -0.265625q-0.15625 -0.15625 -0.25 -0.4375q-0.078125 -0.296875 -0.09375 -0.453125q-0.015625 -0.171875 -0.015625 -0.5l0 -2.84375l-0.75 0l0.109375 -0.78125q0.5 -0.03125 0.75 -0.34375q0.265625 -0.328125 0.46875 -0.890625zm4.767441 1.125l0 4.46875l0.609375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.34375 -0.03125q0.3125 -0.03125 0.3125 -0.34375l0 -2.953125q0 -0.1875 -0.0625 -0.25q-0.0625 -0.0625 -0.203125 -0.0625l-0.375 -0.015625l0 -0.890625l1.90625 0zm-1.34375 -0.921875q-0.21875 -0.21875 -0.21875 -0.5625q0 -0.34375 0.21875 -0.5625q0.21875 -0.234375 0.578125 -0.234375q0.375 0 0.59375 0.234375q0.234375 0.21875 0.234375 0.5625q0 0.34375 -0.234375 0.5625q-0.21875 0.21875 -0.59375 0.21875q-0.359375 0 -0.578125 -0.21875zm5.0136566 5.5q1.171875 0 1.171875 -1.875q0 -0.953125 -0.265625 -1.4375q-0.265625 -0.484375 -0.875 -0.484375q-0.609375 0 -0.90625 0.46875q-0.28125 0.46875 -0.28125 1.28125q0 1.5 0.546875 1.875q0.25 0.171875 0.609375 0.171875zm-2.4375 -1.90625q0 -0.78125 0.234375 -1.359375q0.234375 -0.59375 0.625 -0.890625q0.75 -0.578125 1.671875 -0.578125q0.640625 0 1.078125 0.203125q0.4375 0.203125 0.6875 0.484375q0.25 0.265625 0.421875 0.78125q0.1875 0.515625 0.1875 1.21875q0 1.46875 -0.71875 2.234375q-0.703125 0.75 -1.828125 0.75q-1.109375 0 -1.734375 -0.71875q-0.625 -0.71875 -0.625 -2.125zm5.5484314 -1.8125l0 -0.859375l1.8125 0l0 0.6875q0.265625 -0.390625 0.71875 -0.609375q0.46875 -0.234375 1.0 -0.234375q0.8125 0 1.25 0.46875q0.453125 0.453125 0.453125 1.375l0 2.78125l0.59375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.359375 -0.03125q0.15625 -0.015625 0.21875 -0.078125q0.078125 -0.078125 0.078125 -0.265625l0 -2.203125q0 -0.578125 -0.1875 -0.859375q-0.171875 -0.296875 -0.6875 -0.296875q-0.5 0 -0.796875 0.328125q-0.296875 0.3125 -0.296875 0.75l0 2.578125l0.609375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.34375 -0.03125q0.171875 -0.015625 0.234375 -0.078125q0.078125 -0.078125 0.078125 -0.265625l0 -2.953125q0 -0.328125 -0.265625 -0.328125l-0.453125 -0.03125zm6.831314 3.3125l1.71875 -4.96875l-0.671875 -0.0625l0 -0.875l2.4375 0l2.1875 6.15625l0 0l0.609375 0.03125l0 0.90625l-2.75 0l0 -0.859375l0.40625 -0.03125q0.1875 -0.03125 0.25 -0.09375q0.0625 -0.0625 0 -0.234375l-0.21875 -0.65625l-2.5 0l-0.3125 0.9375l0.640625 0.03125l0 0.90625l-2.46875 0l0 -0.859375l0.359375 -0.03125q0.21875 -0.03125 0.3125 -0.296875zm2.71875 -4.703125l-0.984375 3.03125l1.984375 0l-0.96875 -3.03125l-0.03125 0zm3.8982544 1.390625l0 -0.859375l1.8125 0l0 0.6875q0.265625 -0.390625 0.71875 -0.609375q0.46875 -0.234375 1.0 -0.234375q0.8125 0 1.25 0.46875q0.453125 0.453125 0.453125 1.375l0 2.78125l0.59375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.359375 -0.03125q0.15625 -0.015625 0.21875 -0.078125q0.078125 -0.078125 0.078125 -0.265625l0 -2.203125q0 -0.578125 -0.1875 -0.859375q-0.171875 -0.296875 -0.6875 -0.296875q-0.5 0 -0.796875 0.328125q-0.296875 0.3125 -0.296875 0.75l0 2.578125l0.609375 0.03125l0 0.859375l-2.53125 0l0 -0.8125l0.34375 -0.03125q0.171875 -0.015625 0.234375 -0.078125q0.078125 -0.078125 0.078125 -0.265625l0 -2.953125q0 -0.328125 -0.265625 -0.328125l-0.453125 -0.03125zm9.253189 -0.046875q-0.25 -0.09375 -0.625 -0.09375q-0.359375 0 -0.578125 0.171875q-0.21875 0.15625 -0.21875 0.390625q0 0.234375 0.078125 0.375q0.09375 0.140625 0.265625 0.234375q0.265625 0.140625 
0.625 0.234375q0.375 0.09375 0.5625 0.15625q0.1875 0.0625 0.453125 0.203125q0.265625 0.140625 0.40625 0.296875q0.375 0.390625 0.375 1.015625q0 0.796875 -0.578125 1.25q-0.578125 0.453125 -1.484375 0.453125q-1.296875 0 -1.953125 -0.328125l0 -1.484375l0.953125 -0.078125l0 0.515625q0 0.46875 0.890625 0.46875q0.90625 0 0.90625 -0.65625q0 -0.234375 -0.15625 -0.375q-0.15625 -0.15625 -0.3125 -0.203125q-0.140625 -0.0625 -0.359375 -0.109375q-0.203125 -0.046875 -0.40625 -0.09375q-0.1875 -0.0625 -0.421875 -0.15625q-0.21875 -0.09375 -0.5 -0.265625q-0.546875 -0.34375 -0.546875 -1.171875q0 -0.828125 0.578125 -1.265625q0.59375 -0.453125 1.484375 -0.453125q0.890625 0 1.765625 0.421875l0 1.28125l-0.953125 0.078125l0 -0.453125q0 -0.265625 -0.25 -0.359375zm5.54039 -0.59375l1.03125 0l1.109375 3.953125l0.03125 0l0.71875 -3.28125l-0.546875 -0.03125l0 -0.859375l2.140625 0l0 0.8125l-0.34375 0.03125q-0.21875 0.015625 -0.296875 0.25l-1.0 4.265625l-1.546875 0l-0.859375 -3.171875l-0.046875 0l-0.890625 3.171875l-1.578125 0l-1.078125 -4.25q-0.046875 -0.15625 -0.109375 -0.203125q-0.0625 -0.046875 -0.1875 -0.0625l-0.3125 -0.03125l0 -0.8125l2.4375 0l0 0.859375l-0.59375 0.03125l0.734375 3.265625l0.015625 0l1.171875 -3.9375zm7.6280518 -0.375q0.796875 0 1.265625 0.390625q0.46875 0.390625 0.46875 1.09375q0 0.46875 -0.203125 0.828125q-0.203125 0.34375 -0.5 0.546875q-0.296875 0.203125 -0.734375 0.328125q-0.703125 0.21875 -1.609375 0.21875q0.046875 0.5625 0.359375 0.90625q0.3125 0.34375 0.96875 0.34375q0.671875 0 1.328125 -0.46875l0.40625 0.875q-0.203125 0.1875 -0.71875 0.390625q-0.5 0.203125 -1.15625 0.203125q-1.296875 0 -1.90625 -0.71875q-0.609375 -0.71875 -0.609375 -1.96875q0 -1.265625 0.6875 -2.109375q0.703125 -0.859375 1.953125 -0.859375zm-0.484375 2.4375q0.390625 -0.078125 0.71875 -0.3125q0.328125 -0.25 0.328125 -0.578125q0 -0.640625 -0.640625 -0.640625q-0.59375 0 -0.921875 0.46875q-0.3125 0.46875 -0.34375 1.140625q0.46875 -0.015625 0.859375 -0.078125zm3.6264343 1.890625l0 -2.953125q0 -0.171875 -0.0625 -0.234375q-0.0625 -0.078125 -0.203125 -0.09375l-0.453125 -0.03125l0 -0.859375l1.796875 0l0 0.796875q0.203125 -0.40625 0.59375 -0.671875q0.40625 -0.28125 0.9375 -0.28125q0.53125 0 1.015625 0.234375l0 1.578125l-1.0 0.078125l0 -0.484375q0 -0.21875 -0.109375 -0.265625q-0.109375 -0.0625 -0.28125 -0.0625q-0.421875 0 -0.703125 0.3125q-0.265625 0.296875 -0.265625 0.765625l0 2.46875l0.921875 0.046875l0 0.84375l-2.84375 0l0 -0.8125l0.34375 -0.03125q0.15625 -0.015625 0.234375 -0.078125q0.078125 -0.078125 0.078125 -0.265625zm6.938858 -4.328125q0.796875 0 1.265625 0.390625q0.46875 0.390625 0.46875 1.09375q0 0.46875 -0.203125 0.828125q-0.203125 0.34375 -0.5 0.546875q-0.296875 0.203125 -0.734375 0.328125q-0.703125 0.21875 -1.609375 0.21875q0.046875 0.5625 0.359375 0.90625q0.3125 0.34375 0.96875 0.34375q0.671875 0 1.328125 -0.46875l0.40625 0.875q-0.203125 0.1875 -0.71875 0.390625q-0.5 0.203125 -1.15625 0.203125q-1.296875 0 -1.90625 -0.71875q-0.609375 -0.71875 -0.609375 -1.96875q0 -1.265625 0.6875 -2.109375q0.703125 -0.859375 1.953125 -0.859375zm-0.484375 2.4375q0.390625 -0.078125 0.71875 -0.3125q0.328125 -0.25 0.328125 -0.578125q0 -0.640625 -0.640625 -0.640625q-0.59375 0 -0.921875 0.46875q-0.3125 0.46875 -0.34375 1.140625q0.46875 -0.015625 0.859375 -0.078125zm3.6264343 1.890625l0 -2.953125q0 -0.171875 -0.0625 -0.234375q-0.0625 -0.078125 -0.203125 -0.09375l-0.453125 -0.03125l0 -0.859375l1.796875 0l0 0.796875q0.203125 -0.40625 0.59375 -0.671875q0.40625 -0.28125 0.9375 -0.28125q0.53125 0 1.015625 0.234375l0 1.578125l-1.0 0.078125l0 -0.484375q0 
-0.21875 -0.109375 -0.265625q-0.109375 -0.0625 -0.28125 -0.0625q-0.421875 0 -0.703125 0.3125q-0.265625 0.296875 -0.265625 0.765625l0 2.46875l0.921875 0.046875l0 0.84375l-2.84375 0l0 -0.8125l0.34375 -0.03125q0.15625 -0.015625 0.234375 -0.078125q0.078125 -0.078125 0.078125 -0.265625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m191.64935 35.244095l0 25.29134" fill-rule="evenodd"/><path stroke="#ff9900" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m191.64935 35.244095l0 19.29134" fill-rule="evenodd"/><path fill="#ff9900" stroke="#ff9900" stroke-width="1.0" stroke-linecap="butt" d="m189.99763 54.535435l1.6517181 4.5380974l1.6517334 -4.5380974z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m191.64592 92.78188l0 25.291336" fill-rule="evenodd"/><path stroke="#ff9900" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m191.64592 92.78188l0 19.291344" fill-rule="evenodd"/><path fill="#ff9900" stroke="#ff9900" stroke-width="1.0" stroke-linecap="butt" d="m189.99419 112.07323l1.6517334 4.5380936l1.6517334 -4.5380936z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m421.7318 92.78451l94.92914 25.291336" fill-rule="evenodd"/><path stroke="#ff9900" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m421.73184 92.78451l89.13135 23.746681" fill-rule="evenodd"/><path fill="#ff9900" stroke="#ff9900" stroke-width="1.0" stroke-linecap="butt" d="m510.438 118.12725l4.810364 -0.42775726l-3.9599304 -2.7643585z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m54.16782 92.78188l-0.1889801 25.291336" fill-rule="evenodd"/><path stroke="#ff9900" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m54.16782 92.78188l-0.14414597 19.291504" fill-rule="evenodd"/><path fill="#ff9900" stroke="#ff9900" stroke-width="1.0" stroke-linecap="butt" d="m52.371986 112.06105l1.6177788 4.550316l1.6855965 -4.525635z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m191.64935 35.244095l-137.48032 25.29134" fill-rule="evenodd"/><path stroke="#ff9900" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m191.64935 35.244095l-131.57933 24.205776" fill-rule="evenodd"/><path fill="#ff9900" stroke="#ff9900" stroke-width="1.0" stroke-linecap="butt" d="m59.771175 57.825397l-4.16436 2.4455376l4.762047 0.80340576z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m191.64935 35.244095l230.07874 25.29134" fill-rule="evenodd"/><path stroke="#ff9900" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m191.64935 35.244095l224.11469 24.635738" fill-rule="evenodd"/><path fill="#ff9900" stroke="#ff9900" stroke-width="1.0" stroke-linecap="butt" d="m415.58356 61.521675l4.6914062 -1.1459808l-4.3304443 -2.137703z" fill-rule="evenodd"/><path fill="#ff9900" d="m281.3436 123.44633l0 0c0 -2.968773 2.4066772 -5.375435 5.3754272 -5.375435l78.225525 0c1.4256592 0 2.7929077 0.5663376 3.8009949 1.5744324c1.0080872 1.0080872 1.5744324 2.375351 1.5744324 3.8010025l0 21.501106c0 2.9687653 -2.4066467 5.3754272 -5.3754272 5.3754272l-78.225525 0c-2.96875 0 -5.3754272 -2.406662 -5.3754272 -5.3754272z" fill-rule="evenodd"/><path fill="#000000" d="m293.84845 136.56313l0 -4.28125l-0.53125 -0.03125l0 -0.796875l2.265625 0l0 0.75l-0.3125 0.03125q-0.265625 0.03125 -0.265625 0.296875l0 4.296875l0.53125 0.03125l0 0.796875l-2.265625 0l0 -0.75l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.234375zm2.238739 -2.84375l0 -0.75l1.578125 0l0 
0.609375q0.515625 -0.75 1.421875 -0.75q0.921875 0 1.234375 0.75q0.53125 -0.75 1.421875 -0.75q0.671875 0 1.03125 0.40625q0.359375 0.40625 0.359375 1.21875l0 2.421875l0.53125 0.03125l0 0.75l-2.21875 0l0 -0.703125l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.21875l0 -1.9375q0 -0.515625 -0.125 -0.75q-0.109375 -0.25 -0.53125 -0.25q-0.421875 0 -0.65625 0.265625q-0.234375 0.25 -0.234375 0.6875l0 2.234375l0.515625 0.03125l0 0.75l-2.203125 0l0 -0.703125l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.234375l0 -1.921875q0 -0.515625 -0.125 -0.75q-0.109375 -0.25 -0.546875 -0.25q-0.421875 0 -0.65625 0.265625q-0.21875 0.265625 -0.21875 0.671875l0 2.25l0.5 0.03125l0 0.75l-2.1875 0l0 -0.703125l0.296875 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.078125 -0.0625 0.078125 -0.234375l0 0l0 -2.578125q0 -0.15625 -0.046875 -0.21875q-0.046875 -0.0625 -0.1875 -0.0625l-0.390625 -0.03125zm12.3281555 -0.5625l0 3.46875q0 0.140625 0.046875 0.203125q0.046875 0.046875 0.1875 0.0625l0.296875 0.015625l0 0.75l-1.484375 0l0 -0.546875l-0.03125 0q-0.46875 0.671875 -1.28125 0.671875q-0.953125 0 -1.40625 -0.609375q-0.453125 -0.609375 -0.453125 -1.65625q0 -1.28125 0.609375 -1.984375q0.625 -0.703125 1.84375 -0.703125q0.796875 0 1.671875 0.328125zm-1.078125 3.015625l0 -2.421875q-0.265625 -0.109375 -0.71875 -0.109375q-0.625 0 -0.90625 0.5q-0.28125 0.5 -0.28125 1.328125q0 1.515625 0.96875 1.515625q0.40625 0 0.671875 -0.25q0.265625 -0.25 0.265625 -0.5625zm6.3203125 1.46875q-0.078125 1.046875 -0.59375 1.53125q-0.515625 0.484375 -1.609375 0.484375q-1.078125 0 -1.703125 -0.3125l0 -1.296875l0.890625 -0.078125l0 0.453125q0 0.25 0.234375 0.328125q0.234375 0.09375 0.640625 0.09375q0.59375 0 0.828125 -0.359375q0.25 -0.359375 0.25 -1.0l0 -0.515625q-0.4375 0.515625 -1.234375 0.515625q-0.484375 0 -0.859375 -0.1875q-0.359375 -0.1875 -0.5625 -0.5q-0.40625 -0.640625 -0.40625 -1.484375q0 -1.234375 0.65625 -1.859375q0.671875 -0.625 1.765625 -0.625q0.921875 0 1.734375 0.34375l0 3.453125q0 0.796875 -0.03125 1.015625zm-1.0625 -1.890625l0 -2.0q-0.328125 -0.109375 -0.734375 -0.109375q-0.578125 0 -0.890625 0.421875q-0.296875 0.40625 -0.296875 1.125q0 1.484375 0.96875 1.484375q0.40625 0 0.671875 -0.25q0.28125 -0.25 0.28125 -0.671875zm4.3401184 -2.921875q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 1.0q0.421875 -0.015625 0.75 -0.078125zm6.3585815 -2.515625q-0.09375 -0.171875 -0.765625 -0.171875q-0.859375 0 -1.3125 0.59375q-0.4375 0.578125 -0.4375 1.6875q0 2.3125 1.71875 2.3125q0.03125 0 0.25 0q0.21875 0 0.40625 -0.0625q0.203125 -0.0625 0.234375 -0.125q0.046875 -0.0625 0.046875 -0.203125l0 -0.75l0.953125 0.0625l0 1.609375q-0.734375 0.390625 -1.859375 0.390625q-1.421875 0 -2.171875 -0.796875q-0.75 -0.8125 -0.75 -2.390625q0 -0.890625 0.25 -1.53125q0.25 -0.65625 0.6875 -1.015625q0.828125 -0.703125 1.984375 -0.703125q0.9375 0 1.765625 0.375l0 1.59375l-0.953125 0.0625l0 -0.734375q0 -0.140625 -0.046875 
-0.203125zm3.5374146 -1.4375l0 5.875l0.53125 0.03125l0 0.75l-2.21875 0l0 -0.703125l0.296875 -0.03125q0.28125 -0.03125 0.28125 -0.296875l0 -4.609375q0 -0.203125 -0.234375 -0.21875l-0.328125 -0.015625l0 -0.78125l1.671875 0zm5.29541 2.15625l0 3.46875q0 0.140625 0.046875 0.203125q0.046875 0.046875 0.1875 0.0625l0.296875 0.015625l0 0.75l-1.484375 0l0 -0.546875l-0.03125 0q-0.46875 0.671875 -1.28125 0.671875q-0.953125 0 -1.40625 -0.609375q-0.453125 -0.609375 -0.453125 -1.65625q0 -1.28125 0.609375 -1.984375q0.625 -0.703125 1.84375 -0.703125q0.796875 0 1.671875 0.328125zm-1.078125 3.015625l0 -2.421875q-0.265625 -0.109375 -0.71875 -0.109375q-0.625 0 -0.90625 0.5q-0.28125 0.5 -0.28125 1.328125q0 1.515625 0.96875 1.515625q0.40625 0 0.671875 -0.25q0.265625 -0.25 0.265625 -0.5625zm4.617157 -2.5q-0.21875 -0.078125 -0.546875 -0.078125q-0.3125 0 -0.5 0.140625q-0.1875 0.140625 -0.1875 0.359375q0 0.203125 0.0625 0.328125q0.078125 0.109375 0.21875 0.1875q0.234375 0.125 0.5625 0.21875q0.328125 0.078125 0.484375 0.140625q0.15625 0.046875 0.390625 0.171875q0.25 0.125 0.359375 0.265625q0.328125 0.34375 0.328125 0.875q0 0.703125 -0.515625 1.109375q-0.5 0.390625 -1.28125 0.390625q-1.140625 0 -1.71875 -0.296875l0 -1.296875l0.84375 -0.0625l0 0.453125q0 0.40625 0.78125 0.40625q0.78125 0 0.78125 -0.5625q0 -0.21875 -0.140625 -0.34375q-0.125 -0.125 -0.265625 -0.171875q-0.125 -0.046875 -0.3125 -0.09375q-0.171875 -0.046875 -0.34375 -0.09375q-0.171875 -0.046875 -0.375 -0.125q-0.203125 -0.078125 -0.4375 -0.234375q-0.484375 -0.3125 -0.484375 -1.03125q0 -0.71875 0.515625 -1.109375q0.515625 -0.390625 1.296875 -0.390625q0.78125 0 1.546875 0.375l0 1.109375l-0.84375 0.0625l0 -0.390625q0 -0.234375 -0.21875 -0.3125zm4.2070007 0q-0.21875 -0.078125 -0.546875 -0.078125q-0.3125 0 -0.5 0.140625q-0.1875 0.140625 -0.1875 0.359375q0 0.203125 0.0625 0.328125q0.078125 0.109375 0.21875 0.1875q0.234375 0.125 0.5625 0.21875q0.328125 0.078125 0.484375 0.140625q0.15625 0.046875 0.390625 0.171875q0.25 0.125 0.359375 0.265625q0.328125 0.34375 0.328125 0.875q0 0.703125 -0.515625 1.109375q-0.5 0.390625 -1.28125 0.390625q-1.140625 0 -1.71875 -0.296875l0 -1.296875l0.84375 -0.0625l0 0.453125q0 0.40625 0.78125 0.40625q0.78125 0 0.78125 -0.5625q0 -0.21875 -0.140625 -0.34375q-0.125 -0.125 -0.265625 -0.171875q-0.125 -0.046875 -0.3125 -0.09375q-0.171875 -0.046875 -0.34375 -0.09375q-0.171875 -0.046875 -0.375 -0.125q-0.203125 -0.078125 -0.4375 -0.234375q-0.484375 -0.3125 -0.484375 -1.03125q0 -0.71875 0.515625 -1.109375q0.515625 -0.390625 1.296875 -0.390625q0.78125 0 1.546875 0.375l0 1.109375l-0.84375 0.0625l0 -0.390625q0 -0.234375 -0.21875 -0.3125zm3.4569702 -0.703125l0 3.90625l0.53125 0.03125l0 0.75l-2.21875 0l0 -0.703125l0.296875 -0.03125q0.28125 -0.03125 0.28125 -0.296875l0 -2.59375q0 -0.15625 -0.0625 -0.21875q-0.046875 -0.0625 -0.171875 -0.0625l-0.328125 -0.015625l0 -0.765625l1.671875 0zm-1.1875 -0.8125q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.5q0.203125 -0.203125 0.515625 -0.203125q0.328125 0 0.515625 0.203125q0.203125 0.203125 0.203125 0.5q0 0.296875 -0.203125 0.484375q-0.1875 0.1875 -0.515625 0.1875q-0.3125 0 -0.515625 -0.1875zm2.839264 4.453125l0 -2.859375l-0.796875 0l0 -0.78125l0.796875 0l0 -0.359375q0 -0.90625 0.421875 -1.3125q0.390625 -0.390625 1.140625 -0.390625q0.765625 0 1.359375 0.328125l-0.28125 0.75q-0.4375 -0.234375 -0.84375 -0.234375q-0.390625 0 -0.546875 0.203125q-0.140625 0.1875 -0.140625 0.578125l0 0.4375l1.390625 0l0 0.78125l-1.390625 0l0 3.125l0.796875 0.03125l0 0.75l-2.484375 0l0 -0.703125l0.3125 -0.03125q0.140625 
-0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.234375zm4.5607605 -3.640625l0 3.90625l0.53125 0.03125l0 0.75l-2.21875 0l0 -0.703125l0.296875 -0.03125q0.28125 -0.03125 0.28125 -0.296875l0 -2.59375q0 -0.15625 -0.0625 -0.21875q-0.046875 -0.0625 -0.171875 -0.0625l-0.328125 -0.015625l0 -0.765625l1.671875 0zm-1.1875 -0.8125q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.5q0.203125 -0.203125 0.515625 -0.203125q0.328125 0 0.515625 0.203125q0.203125 0.203125 0.203125 0.5q0 0.296875 -0.203125 0.484375q-0.1875 0.1875 -0.515625 0.1875q-0.3125 0 -0.515625 -0.1875zm4.604889 0.671875q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 1.0q0.421875 -0.015625 0.75 -0.078125zm3.186676 1.65625l0 -2.578125q0 -0.15625 -0.0625 -0.21875q-0.046875 -0.0625 -0.171875 -0.0625l-0.390625 -0.03125l0 -0.75l1.578125 0l0 0.6875q0.15625 -0.359375 0.515625 -0.59375q0.359375 -0.234375 0.8125 -0.234375q0.46875 0 0.890625 0.203125l0 1.390625l-0.875 0.0625l0 -0.421875q0 -0.1875 -0.09375 -0.234375q-0.09375 -0.046875 -0.25 -0.046875q-0.375 0 -0.609375 0.265625q-0.234375 0.265625 -0.234375 0.671875l0 2.15625l0.8125 0.046875l0 0.734375l-2.5 0l0 -0.703125l0.296875 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.078125 -0.0625 0.078125 -0.234375z" fill-rule="nonzero"/><path fill="#ff9900" d="m374.3386 123.446304l0 0c0 -2.968773 2.4066467 -5.375435 5.3754272 -5.375435l84.80817 0c1.4256592 0 2.7929382 0.5663376 3.8010254 1.5744247c1.0080872 1.0080872 1.5744324 2.375351 1.5744324 3.8010101l0 21.501099c0 2.9687653 -2.4066772 5.3754272 -5.375458 5.3754272l-84.80817 0c-2.9687805 0 -5.3754272 -2.406662 -5.3754272 -5.3754272z" fill-rule="evenodd"/><path fill="#000000" d="m387.77484 136.5631l0 -4.28125l-0.53125 -0.03125l0 -0.796875l2.265625 0l0 0.75l-0.3125 0.03125q-0.265625 0.03125 -0.265625 0.296875l0 4.296875l0.53125 0.03125l0 0.796875l-2.265625 0l0 -0.75l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.234375zm2.238739 -2.84375l0 -0.75l1.578125 0l0 0.609375q0.515625 -0.75 1.421875 -0.75q0.921875 0 1.234375 0.75q0.53125 -0.75 1.421875 -0.75q0.671875 0 1.03125 0.40625q0.359375 0.40625 0.359375 1.21875l0 2.421875l0.53125 0.03125l0 0.75l-2.21875 0l0 -0.703125l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.21875l0 -1.9375q0 -0.515625 -0.125 -0.75q-0.109375 -0.25 -0.53125 -0.25q-0.421875 0 -0.65625 0.265625q-0.234375 0.25 -0.234375 0.6875l0 2.234375l0.515625 0.03125l0 0.75l-2.203125 0l0 -0.703125l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.234375l0 -1.921875q0 -0.515625 -0.125 -0.75q-0.109375 -0.25 -0.546875 -0.25q-0.421875 0 -0.65625 0.265625q-0.21875 0.265625 -0.21875 0.671875l0 2.25l0.5 0.03125l0 0.75l-2.1875 0l0 -0.703125l0.296875 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.078125 -0.0625 0.078125 -0.234375l0 0l0 -2.578125q0 -0.15625 -0.046875 -0.21875q-0.046875 -0.0625 -0.1875 -0.0625l-0.390625 -0.03125zm12.3281555 -0.5625l0 3.46875q0 
0.140625 0.046875 0.203125q0.046875 0.046875 0.1875 0.0625l0.296875 0.015625l0 0.75l-1.484375 0l0 -0.546875l-0.03125 0q-0.46875 0.671875 -1.28125 0.671875q-0.953125 0 -1.40625 -0.609375q-0.453125 -0.609375 -0.453125 -1.65625q0 -1.28125 0.609375 -1.984375q0.625 -0.703125 1.84375 -0.703125q0.796875 0 1.671875 0.328125zm-1.078125 3.015625l0 -2.421875q-0.265625 -0.109375 -0.71875 -0.109375q-0.625 0 -0.90625 0.5q-0.28125 0.5 -0.28125 1.328125q0 1.515625 0.96875 1.515625q0.40625 0 0.671875 -0.25q0.265625 -0.25 0.265625 -0.5625zm6.3203125 1.46875q-0.078125 1.046875 -0.59375 1.53125q-0.515625 0.484375 -1.609375 0.484375q-1.078125 0 -1.703125 -0.3125l0 -1.296875l0.890625 -0.078125l0 0.453125q0 0.25 0.234375 0.328125q0.234375 0.09375 0.640625 0.09375q0.59375 0 0.828125 -0.359375q0.25 -0.359375 0.25 -1.0l0 -0.515625q-0.4375 0.515625 -1.234375 0.515625q-0.484375 0 -0.859375 -0.1875q-0.359375 -0.1875 -0.5625 -0.5q-0.40625 -0.640625 -0.40625 -1.484375q0 -1.234375 0.65625 -1.859375q0.671875 -0.625 1.765625 -0.625q0.921875 0 1.734375 0.34375l0 3.453125q0 0.796875 -0.03125 1.015625zm-1.0625 -1.890625l0 -2.0q-0.328125 -0.109375 -0.734375 -0.109375q-0.578125 0 -0.890625 0.421875q-0.296875 0.40625 -0.296875 1.125q0 1.484375 0.96875 1.484375q0.40625 0 0.671875 -0.25q0.28125 -0.25 0.28125 -0.671875zm4.3401184 -2.921875q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 1.0q0.421875 -0.015625 0.75 -0.078125zm4.5460815 1.859375q1.125 0 1.125 -0.890625q0 -0.46875 -0.5 -0.6875q-0.265625 -0.140625 -0.6875 -0.25q-0.40625 -0.109375 -0.6875 -0.21875q-0.265625 -0.125 -0.546875 -0.328125q-0.546875 -0.40625 -0.546875 -1.25q0 -0.859375 0.578125 -1.34375q0.59375 -0.5 1.546875 -0.5q0.96875 0 1.78125 0.375l0 1.40625l-0.9375 0.0625l0 -0.546875q0 -0.140625 -0.046875 -0.203125q-0.09375 -0.171875 -0.65625 -0.171875q-0.546875 0 -0.828125 0.203125q-0.265625 0.1875 -0.265625 0.59375q0 0.296875 0.234375 0.53125q0.140625 0.140625 0.4375 0.25q0.296875 0.09375 0.625 0.203125q0.328125 0.09375 0.578125 0.21875q0.265625 0.109375 0.546875 0.328125q0.578125 0.421875 0.578125 1.296875q0 0.875 -0.640625 1.390625q-0.625 0.5 -1.65625 0.5q-1.03125 0 -1.859375 -0.390625l0 -1.46875l0.953125 -0.0625l0 0.5625q0 0.140625 0.03125 0.21875q0.046875 0.0625 0.234375 0.125q0.203125 0.046875 0.609375 0.046875zm5.2755127 -3.984375q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 
1.0q0.421875 -0.015625 0.75 -0.078125zm6.7492065 2.6875q-0.078125 1.046875 -0.59375 1.53125q-0.515625 0.484375 -1.609375 0.484375q-1.078125 0 -1.703125 -0.3125l0 -1.296875l0.890625 -0.078125l0 0.453125q0 0.25 0.234375 0.328125q0.234375 0.09375 0.640625 0.09375q0.59375 0 0.828125 -0.359375q0.25 -0.359375 0.25 -1.0l0 -0.515625q-0.4375 0.515625 -1.234375 0.515625q-0.484375 0 -0.859375 -0.1875q-0.359375 -0.1875 -0.5625 -0.5q-0.40625 -0.640625 -0.40625 -1.484375q0 -1.234375 0.65625 -1.859375q0.671875 -0.625 1.765625 -0.625q0.921875 0 1.734375 0.34375l0 3.453125q0 0.796875 -0.03125 1.015625zm-1.0625 -1.890625l0 -2.0q-0.328125 -0.109375 -0.734375 -0.109375q-0.578125 0 -0.890625 0.421875q-0.296875 0.40625 -0.296875 1.125q0 1.484375 0.96875 1.484375q0.40625 0 0.671875 -0.25q0.28125 -0.25 0.28125 -0.671875zm1.9338684 -2.03125l0 -0.75l1.578125 0l0 0.609375q0.515625 -0.75 1.421875 -0.75q0.921875 0 1.234375 0.75q0.53125 -0.75 1.421875 -0.75q0.671875 0 1.03125 0.40625q0.359375 0.40625 0.359375 1.21875l0 2.421875l0.53125 0.03125l0 0.75l-2.21875 0l0 -0.703125l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.21875l0 -1.9375q0 -0.515625 -0.125 -0.75q-0.109375 -0.25 -0.53125 -0.25q-0.421875 0 -0.65625 0.265625q-0.234375 0.25 -0.234375 0.6875l0 2.234375l0.515625 0.03125l0 0.75l-2.203125 0l0 -0.703125l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.234375l0 -1.921875q0 -0.515625 -0.125 -0.75q-0.109375 -0.25 -0.546875 -0.25q-0.421875 0 -0.65625 0.265625q-0.21875 0.265625 -0.21875 0.671875l0 2.25l0.5 0.03125l0 0.75l-2.1875 0l0 -0.703125l0.296875 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.078125 -0.0625 0.078125 -0.234375l0 0l0 -2.578125q0 -0.15625 -0.046875 -0.21875q-0.046875 -0.0625 -0.1875 -0.0625l-0.390625 -0.03125zm10.4844055 -0.890625q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 1.0q0.421875 -0.015625 0.75 -0.078125zm2.5460815 -1.25l0 -0.734375l1.578125 0l0 0.609375q0.25 -0.359375 0.640625 -0.546875q0.40625 -0.203125 0.875 -0.203125q0.703125 0 1.09375 0.40625q0.390625 0.40625 0.390625 1.21875l0 2.421875l0.53125 0.03125l0 0.75l-2.21875 0l0 -0.703125l0.296875 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.078125 -0.0625 0.078125 -0.234375l0 -1.921875q0 -0.515625 -0.171875 -0.75q-0.15625 -0.25 -0.59375 -0.25q-0.4375 0 -0.703125 0.28125q-0.265625 0.265625 -0.265625 0.65625l0 2.25l0.53125 0.03125l0 0.75l-2.21875 0l0 -0.703125l0.3125 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.0625 -0.0625 0.0625 -0.234375l0 -2.578125q0 -0.28125 -0.234375 -0.296875l-0.390625 -0.03125zm6.3172913 -1.734375l0.609375 0l0 1.0l1.234375 0l-0.09375 0.78125l-1.140625 0l0 2.453125q0 0.40625 0.140625 0.578125q0.15625 0.15625 0.46875 0.15625q0.3125 0 0.625 -0.203125l0.28125 0.71875q-0.46875 0.328125 -1.1875 0.328125q-0.421875 0 -0.71875 -0.109375q-0.28125 -0.09375 -0.421875 -0.21875q-0.140625 -0.140625 -0.21875 -0.390625q-0.0625 -0.25 -0.078125 -0.390625q0 -0.15625 0 -0.4375l0 
-2.484375l-0.671875 0l0.09375 -0.6875q0.4375 -0.03125 0.65625 -0.296875q0.234375 -0.28125 0.421875 -0.796875zm4.829193 0.859375q0.6875 0 1.09375 0.34375q0.421875 0.328125 0.421875 0.953125q0 0.40625 -0.1875 0.71875q-0.171875 0.3125 -0.4375 0.5q-0.25 0.171875 -0.625 0.28125q-0.625 0.1875 -1.40625 0.1875q0.03125 0.484375 0.296875 0.796875q0.28125 0.296875 0.859375 0.296875q0.578125 0 1.15625 -0.40625l0.359375 0.75q-0.1875 0.171875 -0.640625 0.359375q-0.4375 0.171875 -1.0 0.171875q-1.140625 0 -1.671875 -0.625q-0.53125 -0.640625 -0.53125 -1.734375q0 -1.109375 0.609375 -1.84375q0.609375 -0.75 1.703125 -0.75zm-0.4375 2.125q0.34375 -0.0625 0.625 -0.265625q0.296875 -0.21875 0.296875 -0.5q0 -0.578125 -0.5625 -0.578125q-0.515625 0 -0.796875 0.421875q-0.28125 0.421875 -0.3125 1.0q0.421875 -0.015625 0.75 -0.078125zm3.1867065 1.65625l0 -2.578125q0 -0.15625 -0.0625 -0.21875q-0.046875 -0.0625 -0.171875 -0.0625l-0.390625 -0.03125l0 -0.75l1.578125 0l0 0.6875q0.15625 -0.359375 0.515625 -0.59375q0.359375 -0.234375 0.8125 -0.234375q0.46875 0 0.890625 0.203125l0 1.390625l-0.875 0.0625l0 -0.421875q0 -0.1875 -0.09375 -0.234375q-0.09375 -0.046875 -0.25 -0.046875q-0.375 0 -0.609375 0.265625q-0.234375 0.265625 -0.234375 0.671875l0 2.15625l0.8125 0.046875l0 0.734375l-2.5 0l0 -0.703125l0.296875 -0.03125q0.140625 -0.015625 0.203125 -0.078125q0.078125 -0.0625 0.078125 -0.234375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m421.7318 92.78451l0.3779602 25.291336" fill-rule="evenodd"/><path stroke="#ff9900" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m421.73184 92.78451l0.28829956 19.292007" fill-rule="evenodd"/><path fill="#ff9900" stroke="#ff9900" stroke-width="1.0" stroke-linecap="butt" d="m420.3686 112.1012l1.7193604 4.5129166l1.5837402 -4.562271z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m421.7318 92.78451l-95.90549 25.291336" fill-rule="evenodd"/><path stroke="#ff9900" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m421.73184 92.78451l-90.10388 23.761383" fill-rule="evenodd"/><path fill="#ff9900" stroke="#ff9900" stroke-width="1.0" stroke-linecap="butt" d="m331.2068 114.94876l-3.9668884 2.7543106l4.809265 0.43994904z" fill-rule="evenodd"/></g></svg>
\ No newline at end of file
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/segmentation-output.png b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/segmentation-output.png
new file mode 100644
index 0000000..e871df3
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/segmentation-output.png
Binary files differ
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/sparrow.jpg b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/sparrow.jpg
new file mode 100644
index 0000000..25d213e
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/images/sparrow.jpg
Binary files differ
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md
new file mode 100644
index 0000000..c2c1b2a
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md
@@ -0,0 +1,151 @@
+# Integrate Natural language classifier
+
+The Task Library's `NLClassifier` API classifies input text into different
+categories. It is a versatile and configurable API that can handle most text
+classification models.
+
+## Key features of the NLClassifier API
+
+*   Takes a single string as input, performs classification on the string, and
+    outputs <Label, Score> pairs as classification results.
+
+*   Optional Regex Tokenization available for input text.
+
+*   Configurable to adapt to different classification models.
+
+## Supported NLClassifier models
+
+The following models are guaranteed to be compatible with the `NLClassifier`
+API.
+
+*   The <a href="../../models/text_classification/overview.md">movie review
+    sentiment classification</a> model.
+
+*   Models with `average_word_vec` spec created by
+    [TensorFlow Lite Model Maker for Text Classification](https://www.tensorflow.org/lite/tutorials/model_maker_text_classification).
+
+*   Custom models that meet the
+    [model compatibility requirements](#model-compatibility-requirements).
+
+## Run inference in Java
+
+### Step 1: Import Gradle dependency and other settings
+
+Copy the `.tflite` model file to the assets directory of the Android module
+where the model will be run. Specify that the file should not be compressed, and
+add the TensorFlow Lite library to the module’s `build.gradle` file:
+
+```java
+android {
+    // Other settings
+
+    // Specify tflite file should not be compressed for the app apk
+    aaptOptions {
+        noCompress "tflite"
+    }
+
+}
+
+dependencies {
+    // Other dependencies
+
+    // Import the Task Text Library dependency
+    implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
+}
+```
+
+### Step 2: Run inference using the API
+
+```java
+// Initialization, use NLClassifierOptions to configure input and output tensors
+NLClassifierOptions options =
+    NLClassifierOptions.builder()
+        .setInputTensorName(INPUT_TENSOR_NAME)
+        .setOutputScoreTensorName(OUTPUT_SCORE_TENSOR_NAME)
+        .build();
+NLClassifier classifier = NLClassifier.createFromFileAndOptions(context, modelFile, options);
+
+// Run inference
+List<Category> results = classifier.classify(input);
+```
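+
+The returned `results` are a list of `Category` objects (from
+`org.tensorflow.lite.support.label.Category`). A minimal sketch of inspecting
+them, assuming the snippet above:
+
+```java
+// Log each predicted label together with its score.
+for (Category category : results) {
+  Log.i("NLClassifier", category.getLabel() + ": " + category.getScore());
+}
+```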
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java)
+for more options to configure `NLClassifier`.
+
+## Run inference in C++
+
+Note: We are working on improving the usability of the C++ Task Library, such as
+providing prebuilt binaries and creating user-friendly workflows to build from
+source code. The C++ API may be subject to change.
+
+```c++
+// Initialization
+std::unique_ptr<NLClassifier> classifier = NLClassifier::CreateFromFileAndOptions(
+    model_path,
+    {
+      .input_tensor_name=kInputTensorName,
+      .output_score_tensor_name=kOutputScoreTensorName,
+    }).value();
+
+// Run inference
+std::vector<core::Category> categories = classifier->Classify(kInput);
+```
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h)
+for more details.
+
+## Example results
+
+Here is an example of the classification results of the
+[movie review model](https://www.tensorflow.org/lite/models/text_classification/overview).
+
+Input: "What a waste of my time."
+
+Output:
+
+```
+category[0]: 'Negative' : '0.81313'
+category[1]: 'Positive' : '0.18687'
+```
+
+Try out the simple
+[CLI demo tool for NLClassifier](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/text/desktop/README.md#nlclassifier)
+with your own model and test data.
+
+## Model compatibility requirements
+
+Depending on the use case, the `NLClassifier` API can load a TFLite model with
+or without [TFLite Model Metadata](../../convert/metadata.md).
+
+The compatible models should meet the following requirements:
+
+*   Input tensor: (kTfLiteString/kTfLiteInt32)
+
+    -   The model's input should be either a kTfLiteString tensor containing
+        the raw input string or a kTfLiteInt32 tensor containing the
+        regex-tokenized indices of the raw input string.
+    -   If the input type is kTfLiteString, no
+        [Metadata](../../convert/metadata.md) is required for the model.
+    -   If the input type is kTfLiteInt32, a `RegexTokenizer` needs to be set
+        up in the input tensor's [Metadata](../../convert/metadata.md).
+
+*   Output score tensor:
+    (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64)
+
+    -   Mandatory output tensor for the score of each category classified.
+
+    -   If the type is one of the integer types, the score is dequantized to
+        double/float on the corresponding platform (see the sketch after this
+        list).
+
+    -   Can have an optional associated file in the output tensor's
+        corresponding [Metadata](../../convert/metadata.md) for category
+        labels. The file should be a plain text file with one label per line,
+        and the number of labels should match the number of categories the
+        model outputs.
+
+*   Output label tensor: (kTfLiteString/kTfLiteInt32)
+
+    -   Optional output tensor for the label of each category; it should be of
+        the same length as the output score tensor. If this tensor is not
+        present, the API uses score indices as class names.
+
+    -   Will be ignored if the associated label file is present in the output
+        score tensor's Metadata.
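+
+Score dequantization follows the standard affine scheme
+`real_value = scale * (quantized_value - zero_point)`, where `scale` and
+`zero_point` are read from the output tensor's quantization parameters. A
+minimal sketch for a kTfLiteUInt8 score tensor (the helper name is
+hypothetical, not part of the API):
+
+```java
+// Convert a raw uint8 score into a float score using the tensor's
+// quantization parameters (scale, zeroPoint).
+static float dequantizeScore(byte rawScore, float scale, int zeroPoint) {
+  int unsigned = rawScore & 0xFF;  // reinterpret the byte as uint8
+  return scale * (unsigned - zeroPoint);
+}
+```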
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md
new file mode 100644
index 0000000..09ce3a1
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/object_detector.md
@@ -0,0 +1,179 @@
+# Integrate object detectors
+
+Object detectors can identify which of a known set of objects might be present
+and provide information about their positions within the given image or a video
+stream. An object detector is trained to detect the presence and location of
+multiple classes of objects. For example, a model might be trained with images
+that contain various pieces of fruit, along with a _label_ that specifies the
+class of fruit they represent (e.g. an apple, a banana, or a strawberry), and
+data specifying where each object appears in the image. See the
+[introduction of object detection](../../models/object_detection/overview.md)
+for more information about object detectors.
+
+Use the Task Library `ObjectDetector` API to deploy your custom object detectors
+or pretrained ones into your mobile apps.
+
+## Key features of the ObjectDetector API
+
+*   Input image processing, including rotation, resizing, and color space
+    conversion.
+
+*   Label map locale.
+
+*   Score threshold to filter results.
+
+*   Top-k detection results.
+
+*   Label allowlist and denylist.
+
+## Supported object detector models
+
+The following models are guaranteed to be compatible with the `ObjectDetector`
+API.
+
+*   The
+    [pretrained object detection models on TensorFlow Hub](https://tfhub.dev/tensorflow/collections/lite/task-library/object-detector/1).
+
+*   Models created by
+    [AutoML Vision Edge Object Detection](https://cloud.google.com/vision/automl/object-detection/docs).
+
+*   Custom models that meet the
+    [model compatibility requirements](#model-compatibility-requirements).
+
+## Run inference in Java
+
+### Step 1: Import Gradle dependency and other settings
+
+Copy the `.tflite` model file to the assets directory of the Android module
+where the model will be run. Specify that the file should not be compressed, and
+add the TensorFlow Lite library to the module’s `build.gradle` file:
+
+```java
+android {
+    // Other settings
+
+    // Specify tflite file should not be compressed for the app apk
+    aaptOptions {
+        noCompress "tflite"
+    }
+
+}
+
+dependencies {
+    // Other dependencies
+
+    // Import the Task Vision Library dependency
+    implementation 'org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly'
+}
+```
+
+### Step 2: Run inference using the API
+
+```java
+// Initialization
+ObjectDetectorOptions options = ObjectDetectorOptions.builder().setMaxResults(1).build();
+ObjectDetector objectDetector = ObjectDetector.createFromFileAndOptions(context, modelFile, options);
+
+// Run inference
+List<Detection> results = objectDetector.detect(image);
+```
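+
+The `image` argument above is a `TensorImage` (from
+`org.tensorflow.lite.support.image.TensorImage`). A minimal sketch of building
+one from an Android `Bitmap` (the image path here is hypothetical):
+
+```java
+// Decode an image file into a Bitmap and wrap it as the detector's input.
+Bitmap bitmap = BitmapFactory.decodeFile("/path/to/image.jpg");
+TensorImage image = TensorImage.fromBitmap(bitmap);
+```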
+
+See the
+[source code and javadoc](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java)
+for more options to configure `ObjectDetector`.
+
+## Run inference in C++
+
+Note: We are working on improving the usability of the C++ Task Library, such as
+providing prebuilt binaries and creating user-friendly workflows to build from
+source code. The C++ API may be subject to change.
+
+```c++
+// Initialization
+ObjectDetectorOptions options;
+options.mutable_model_file_with_metadata()->set_file_name(model_file);
+std::unique_ptr<ObjectDetector> object_detector = ObjectDetector::CreateFromOptions(options).value();
+
+// Run inference
+const DetectionResult result = object_detector->Detect(*frame_buffer).value();
+```
+
+See the
+[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/vision/object_detector.h)
+for more options to configure `ObjectDetector`.
+
+## Example results
+
+Here is an example of the detection results of
+[ssd mobilenet v1](https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1)
+from TensorFlow Hub.
+
+<img src="images/dogs.jpg" alt="dogs" width="50%">
+
+```
+Results:
+ Detection #0 (red):
+  Box: (x: 355, y: 133, w: 190, h: 206)
+  Top-1 class:
+   index       : 17
+   score       : 0.73828
+   class name  : dog
+ Detection #1 (green):
+  Box: (x: 103, y: 15, w: 138, h: 369)
+  Top-1 class:
+   index       : 17
+   score       : 0.73047
+   class name  : dog
+```
+
+Render the bounding boxes onto the input image:
+
+<img src="images/detection-output.png" alt="detection output" width="50%">
+
+Try out the simple
+[CLI demo tool for ObjectDetector](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/examples/task/vision/desktop#object-detector)
+with your own model and test data.
+
+## Model compatibility requirements
+
+The `ObjectDetector` API expects a TFLite model with mandatory
+[TFLite Model Metadata](../../convert/metadata.md).
+
+The compatible object detector models should meet the following requirements:
+
+*   Input image tensor: (kTfLiteUInt8/kTfLiteFloat32)
+
+    -   image input of size `[batch x height x width x channels]`.
+    -   batch inference is not supported (`batch` is required to be 1).
+    -   only RGB inputs are supported (`channels` is required to be 3).
+    -   if type is kTfLiteFloat32, NormalizationOptions are required to be
+        attached to the metadata for input normalization.
+
+*   Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e:
+
+    -   Locations tensor (kTfLiteFloat32)
+        -   tensor of size `[1 x num_results x 4]`, the inner array representing
+            bounding boxes in the form [top, left, right, bottom].
+        -   BoundingBoxProperties are required to be attached to the metadata
+            and must specify `type=BOUNDARIES` and `coordinate_type=RATIO`
+            (see the sketch after this list).
+    -   Classes tensor (kTfLiteFloat32)
+
+        -   tensor of size `[1 x num_results]`, each value representing the
+            integer index of a class.
+        -   optional (but recommended) label map(s) can be attached as
+            AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label
+            per line. The first such AssociatedFile (if any) is used to fill the
+            `class_name` field of the results. The `display_name` field is
+            filled from the AssociatedFile (if any) whose locale matches the
+            `display_names_locale` field of the `ObjectDetectorOptions` used at
+            creation time ("en" by default, i.e. English). If none of these are
+            available, only the `index` field of the results will be filled.
+
+    -   Scores tensor (kTfLiteFloat32)
+
+        -   tensor of size `[1 x num_results]`, each value representing the
+            score of the detected object.
+
+    -   Number of detection tensor (kTfLiteFloat32)
+
+        -   integer num_results as a tensor of size `[1]`.
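+
+As noted above, the raw locations are normalized ratios when
+`coordinate_type=RATIO`; the Java API reports pixel coordinates, as in the
+example results above. If you parse the raw output tensors yourself, a minimal
+sketch of the conversion (the helper name is hypothetical):
+
+```java
+// Scale a normalized [top, left, right, bottom] box to pixel coordinates.
+static float[] ratioBoxToPixels(float[] box, int imageWidth, int imageHeight) {
+  return new float[] {
+    box[0] * imageHeight,  // top
+    box[1] * imageWidth,   // left
+    box[2] * imageWidth,   // right
+    box[3] * imageHeight,  // bottom
+  };
+}
+```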
diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md
new file mode 100644
index 0000000..2a2b124
--- /dev/null
+++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md
@@ -0,0 +1,51 @@
+# TensorFlow Lite Task Library
+
+TensorFlow Lite Task Library contains a set of powerful and easy-to-use
+task-specific libraries for app developers to create ML experiences with TFLite.
+It provides optimized out-of-box model interfaces for popular machine learning
+tasks, such as image classification and question answering. The model
+interfaces are specifically designed for each task to achieve the best
+performance and usability. Task Library works cross-platform and is supported on
+Java, C++, and Swift (coming soon).
+
+## What to expect from the Task Library
+
+*   **Clean and well-defined APIs usable by non-ML-experts** \
+    Inference can be done within just 5 lines of code. Use the powerful and
+    easy-to-use APIs in the Task library as building blocks to help you easily
+    develop ML with TFLite on mobile devices.
+
+*   **Complex but common data processing** \
+    Supports common vision and natural language processing logic to convert
+    between your data and the data format required by the model. Provides the
+    same, shareable processing logic for training and inference.
+
+*   **High performance gain** \
+    Data processing takes no more than a few milliseconds, ensuring a fast
+    inference experience with TensorFlow Lite.
+
+*   **Extensibility and customization** \
+    You can leverage all benefits the Task Library infrastructure provides and
+    easily build your own Android/iOS inference APIs.
+
+## Supported tasks
+
+Below is the list of the supported task types. The list is expected to grow as
+we continue enabling more and more use cases.
+
+*   **Vision APIs**
+
+    *   [ImageClassifier](image_classifier.md)
+    *   [ObjectDetector](object_detector.md)
+    *   [ImageSegmenter](image_segmenter.md)
+
+*   **Natural Language (NL) APIs**
+
+    *   [NLClassifier](nl_classifier.md)
+    *   [BertNLClassifier](bert_nl_classifier.md)
+    *   [BertQuestionAnswerer](bert_question_answerer.md)
+
+*   **Custom APIs**
+
+    *   Extend Task API infrastructure and build
+        [customized API](customized_task_api.md).
diff --git a/tensorflow/lite/g3doc/performance/benchmarks.md b/tensorflow/lite/g3doc/performance/benchmarks.md
deleted file mode 100644
index 7b1eb5c..0000000
--- a/tensorflow/lite/g3doc/performance/benchmarks.md
+++ /dev/null
@@ -1,204 +0,0 @@
-# Performance benchmarks
-
-This document lists TensorFlow Lite performance benchmarks when running well
-known models on some Android and iOS devices.
-
-These performance benchmark numbers were generated with the
-[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark)
-and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios).
-
-## Android performance benchmarks
-
-For Android benchmarks, the CPU affinity is set to use big cores on the device to
-reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
-
-It assumes that models were download and unzipped to the
-`/data/local/tmp/tflite_models` directory. The benchmark binary is built
-using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#on-android)
-and assumed in the `/data/local/tmp` directory.
-
-To run the benchmark:
-
-```sh
-adb shell /data/local/tmp/benchmark_model \
-  --num_threads=4 \
-  --graph=/data/local/tmp/tflite_models/${GRAPH} \
-  --warmup_runs=1 \
-  --num_runs=50
-```
-
-To run with nnapi delegate, please set `--use_nnapi=true`. To run with gpu
-delegate, please set `--use_gpu=true`.
-
-The performance values below are measured on Android 10.
-
-<table>
-  <thead>
-    <tr>
-      <th>Model Name</th>
-      <th>Device </th>
-      <th>CPU, 4 threads</th>
-      <th>GPU</th>
-      <th>NNAPI</th>
-    </tr>
-  </thead>
-  <tr>
-    <td rowspan = 2>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
-    </td>
-    <td>Pixel 3 </td>
-    <td>23.9 ms</td>
-    <td>6.45 ms</td>
-    <td>13.8 ms</td>
-  </tr>
-   <tr>
-     <td>Pixel 4 </td>
-    <td>14.0 ms</td>
-    <td>9.0 ms</td>
-    <td>14.8 ms</td>
-  </tr>
-  <tr>
-    <td rowspan = 2>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
-    </td>
-    <td>Pixel 3 </td>
-    <td>13.4 ms</td>
-    <td>--- </td>
-    <td>6.0 ms</td>
-  </tr>
-   <tr>
-     <td>Pixel 4 </td>
-    <td>5.0 ms</td>
-    <td>--- </td>
-    <td>3.2 ms</td>
-  </tr>
-  <tr>
-    <td rowspan = 2>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
-    </td>
-    <td>Pixel 3 </td>
-    <td>56 ms</td>
-    <td>--- </td>
-    <td>102 ms</td>
-  </tr>
-   <tr>
-     <td>Pixel 4 </td>
-    <td>34.5 ms</td>
-    <td>--- </td>
-    <td>99.0 ms</td>
-  </tr>
-  <tr>
-    <td rowspan = 2>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
-    </td>
-    <td>Pixel 3 </td>
-    <td>35.8 ms</td>
-    <td>9.5 ms </td>
-    <td>18.5 ms</td>
-  </tr>
-   <tr>
-     <td>Pixel 4 </td>
-    <td>23.9 ms</td>
-    <td>11.1 ms</td>
-    <td>19.0 ms</td>
-  </tr>
-  <tr>
-    <td rowspan = 2>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
-    </td>
-    <td>Pixel 3 </td>
-    <td>422 ms</td>
-    <td>99.8 ms </td>
-    <td>201 ms</td>
-  </tr>
-   <tr>
-     <td>Pixel 4 </td>
-    <td>272.6 ms</td>
-    <td>87.2 ms</td>
-    <td>171.1 ms</td>
-  </tr>
-  <tr>
-    <td rowspan = 2>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
-    </td>
-    <td>Pixel 3 </td>
-    <td>486 ms</td>
-    <td>93 ms </td>
-    <td>292 ms</td>
-  </tr>
-   <tr>
-     <td>Pixel 4 </td>
-    <td>324.1 ms</td>
-    <td>97.6 ms</td>
-    <td>186.9 ms</td>
-  </tr>
-
- </table>
-
-## iOS benchmarks
-
-To run iOS benchmarks, the
-[benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios)
-was modified to include the appropriate model and `benchmark_params.json` was
-modified to set `num_threads` to 2. For GPU delegate, `"use_gpu" : "1"` and
-`"gpu_wait_type" : "aggressive"` options were also added to
-`benchmark_params.json`.
-
-<table>
-  <thead>
-    <tr>
-      <th>Model Name</th>
-      <th>Device </th>
-      <th>CPU, 2 threads</th>
-      <th>GPU</th>
-    </tr>
-  </thead>
-  <tr>
-    <td>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
-    </td>
-    <td>iPhone XS </td>
-    <td>14.8 ms</td>
-    <td>3.4 ms</td>
-  </tr>
-  <tr>
-    <td>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
-    </td>
-    <td>iPhone XS </td>
-    <td>11 ms</td>
-    <td>---</td>
-  </tr>
-  <tr>
-    <td>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
-    </td>
-    <td>iPhone XS </td>
-    <td>30.4 ms</td>
-    <td>---</td>
-  </tr>
-  <tr>
-    <td>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
-    </td>
-    <td>iPhone XS </td>
-    <td>21.1 ms</td>
-    <td>15.5 ms</td>
-  </tr>
-  <tr>
-    <td>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
-    </td>
-    <td>iPhone XS </td>
-    <td>261.1 ms</td>
-    <td>45.7 ms</td>
-  </tr>
-  <tr>
-    <td>
-      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
-    </td>
-    <td>iPhone XS </td>
-    <td>309 ms</td>
-    <td>54.4 ms</td>
-  </tr>
- </table>
diff --git a/tensorflow/lite/g3doc/performance/best_practices.md b/tensorflow/lite/g3doc/performance/best_practices.md
index e4abb56..9df0ace 100644
--- a/tensorflow/lite/g3doc/performance/best_practices.md
+++ b/tensorflow/lite/g3doc/performance/best_practices.md
@@ -38,6 +38,12 @@
 help in understanding performance bottlenecks and which operators dominate the
 computation time.
 
+You can also use
+[TensorFlow Lite tracing](measurement.md#trace_tensorflow_lite_internals_in_android)
+to profile the model in your Android application, using standard Android system
+tracing, and to visualize the operator invocations by time with GUI-based
+profiling tools.
+
 ## Profile and optimize operators in the graph
 
 If a particular operator appears frequently in the model and, based on
@@ -116,7 +122,7 @@
 
 Be aware that some accelerators work better for different types of models. Some
 delegates only support float models or models optimized in a specific way. It is
-important to [benchmark](benchmarks.md) each delegate to see if it is a good
+important to [benchmark](measurement.md) each delegate to see if it is a good
 choice for your application. For example, if you have a very small model, it may
 not be worth delegating the model to either the NN API or the GPU. Conversely,
 accelerators are a great choice for large models that have high arithmetic
diff --git a/tensorflow/lite/g3doc/performance/images/as_select_profiling_mode.png b/tensorflow/lite/g3doc/performance/images/as_select_profiling_mode.png
new file mode 100644
index 0000000..9ba5ba8
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/images/as_select_profiling_mode.png
Binary files differ
diff --git a/tensorflow/lite/g3doc/performance/images/as_traces.png b/tensorflow/lite/g3doc/performance/images/as_traces.png
new file mode 100644
index 0000000..cbc2b14
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/images/as_traces.png
Binary files differ
diff --git a/tensorflow/lite/g3doc/performance/images/perfetto_traces.png b/tensorflow/lite/g3doc/performance/images/perfetto_traces.png
new file mode 100644
index 0000000..94b2654
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/images/perfetto_traces.png
Binary files differ
diff --git a/tensorflow/lite/g3doc/performance/measurement.md b/tensorflow/lite/g3doc/performance/measurement.md
new file mode 100644
index 0000000..9d2f724
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/measurement.md
@@ -0,0 +1,524 @@
+# Performance measurement
+
+## Benchmark tools
+
+TensorFlow Lite benchmark tools currently measure and calculate statistics for
+the following important performance metrics:
+
+*   Initialization time
+*   Inference time of warmup state
+*   Inference time of steady state
+*   Memory usage during initialization time
+*   Overall memory usage
+
+The benchmark tools are available as benchmark apps for Android and iOS and as
+native command-line binaries, and they all share the same core performance
+measurement logic. Note that the available options and output formats are
+slightly different due to the differences in runtime environment.
+
+### Android benchmark app
+
+There are two options for using the benchmark tool on Android. One is a
+[native benchmark binary](#native-benchmark-binary), and the other is an
+Android benchmark app, which is a better gauge of how the model would perform
+in an actual app. Either way, the numbers from the benchmark tool will still
+differ slightly from those measured when running inference with the model in
+the actual app.
+
+This Android benchmark app has no UI. Install and run it by using the `adb`
+command and retrieve results by using the `adb logcat` command.
+
+#### Download or build the app
+
+Download the nightly pre-built Android benchmark apps using the links below:
+
+*   [android_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model.apk)
+
+*   [android_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_arm_benchmark_model.apk)
+
+You can also build the app from source by following these
+[instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/android).
+
+Note: You must build the app from source if you want to run the Android
+benchmark APK on an x86 CPU or with the Hexagon delegate, or if your model
+contains [select TF operators](../guide/ops_select) or
+[custom operators](../guide/ops_custom).
+
+#### Prepare benchmark
+
+Before running the benchmark app, install the app and push the model file to the
+device as follows:
+
+```shell
+adb install -r -d -g android_aarch64_benchmark_model.apk
+adb push your_model.tflite /data/local/tmp
+```
+
+#### Run benchmark
+
+```shell
+adb shell am start -S \
+  -n org.tensorflow.lite.benchmark/.BenchmarkModelActivity \
+  --es args '"--graph=/data/local/tmp/your_model.tflite \
+              --num_threads=4"'
+```
+
+`graph` is a required parameter.
+
+*   `graph`: `string` \
+    The path to the TFLite model file.
+
+You can specify more optional parameters for running the benchmark.
+
+*   `num_threads`: `int` (default=1) \
+    The number of threads to use for running the TFLite interpreter.
+*   `use_gpu`: `bool` (default=false) \
+    Use [GPU delegate](gpu).
+*   `use_nnapi`: `bool` (default=false) \
+    Use [NNAPI delegate](nnapi).
+*   `use_xnnpack`: `bool` (default=`false`) \
+    Use
+    [XNNPACK delegate](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack).
+*   `use_hexagon`: `bool` (default=`false`) \
+    Use [Hexagon delegate](hexagon_delegate).
+
+Depending on the device you are using, some of these options may not be
+available or may have no effect. Refer to
+[parameters](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#parameters)
+for more performance parameters that you can use with the benchmark app.
+
+View the results using the `logcat` command:
+
+```shell
+adb logcat | grep "Average inference"
+```
+
+The benchmark results are reported as:
+
+```
+... tflite  : Average inference timings in us: Warmup: 91471, Init: 4108, Inference: 80660.1
+```
+
+### Native benchmark binary
+
+The benchmark tool is also provided as a native binary, `benchmark_model`. You
+can execute this tool from a shell command line on Linux, macOS, embedded
+devices, and Android devices.
+
+#### Download or build the binary
+
+Download the nightly pre-built native command-line binaries by following the
+links below:
+
+*   [linux_x86-64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_x86-64_benchmark_model)
+*   [linux_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_aarch64_benchmark_model)
+*   [linux_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_arm_benchmark_model)
+*   [android_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model)
+*   [android_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_arm_benchmark_model)
+
+You can also build the native benchmark binary from
+[source](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark)
+on your computer.
+
+```shell
+bazel build -c opt //tensorflow/lite/tools/benchmark:benchmark_model
+```
+
+To build with Android NDK toolchain, you need to set up the build environment
+first by following this
+[guide](../guide/build_android#set_up_build_environment_without_docker), or use
+the docker image as described in this
+[guide](../guide/build_android#set_up_build_environment_using_docker).
+
+```shell
+bazel build -c opt --config=android_arm64 \
+  //tensorflow/lite/tools/benchmark:benchmark_model
+```
+
+Note: It is a valid approach to push and execute binaries directly on an Android
+device for benchmarking, but it can result in subtle (but observable)
+differences in performance relative to execution within an actual Android app.
+In particular, Android's scheduler tailors behavior based on thread and process
+priorities, which differ between a foreground Activity/Application and a regular
+background binary executed via `adb shell ...`. This tailored behavior is most
+evident when enabling multi-threaded CPU execution with TensorFlow Lite.
+Therefore, the Android benchmark app is preferred for performance measurement.
+
+#### Run benchmark
+
+To run benchmarks on your computer, execute the binary from the shell.
+
+```shell
+path/to/downloaded_or_built/benchmark_model \
+  --graph=your_model.tflite \
+  --num_threads=4
+```
+
+You can use the same set of
+[parameters](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#parameters)
+as mentioned above with the native command-line binary.
+
+#### Profiling model ops
+
+The benchmark model binary also allows you to profile model ops and get the
+execution times of each operator. To do this, pass the flag
+`--enable_op_profiling=true` to `benchmark_model` during invocation. Details are
+explained
+[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#profiling-model-operators).
+
+### Native benchmark binary for multiple performance options in a single run
+
+A convenient and simple C++ binary is also provided to
+[benchmark multiple performance options](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#benchmark-multiple-performance-options-in-a-single-run)
+in a single run. This binary is built based on the aforementioned benchmark tool
+that can only benchmark a single performance option at a time. They share the
+same build/install/run process, but the BUILD target name of this binary is
+`benchmark_model_performance_options` and it takes some additional parameters.
+An important parameter for this binary is:
+
+`perf_options_list`: `string` (default='all') \
+A comma-separated list of TFLite performance options to benchmark.
+
+You can get nightly pre-built binaries for this tool as listed below:
+
+*   [linux_x86-64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_x86-64_benchmark_model_performance_options)
+*   [linux_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_aarch64_benchmark_model_performance_options)
+*   [linux_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_arm_benchmark_model_performance_options)
+*   [android_aarch64](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model_performance_options)
+*   [android_arm](https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_arm_benchmark_model_performance_options)
+
+### iOS benchmark app
+
+To run benchmarks on an iOS device, you need to build the app from
+[source](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios).
+Put the TensorFlow Lite model file in the
+[benchmark_data](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data)
+directory of the source tree and modify the `benchmark_params.json` file. Those
+files are packaged into the app and the app reads data from the directory. Visit
+the
+[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios)
+for detailed instructions.
+
+## Performance benchmarks for well known models
+
+This section lists TensorFlow Lite performance benchmarks when running well
+known models on some Android and iOS devices.
+
+### Android performance benchmarks
+
+These performance benchmark numbers were generated with the
+[native benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark).
+
+For Android benchmarks, the CPU affinity is set to use big cores on the device
+to reduce variance (see
+[details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
+
+It assumes that models were downloaded and unzipped to the
+`/data/local/tmp/tflite_models` directory. The benchmark binary is built using
+[these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#on-android)
+and assumed to be in the `/data/local/tmp` directory.
+
+To run the benchmark:
+
+```sh
+adb shell /data/local/tmp/benchmark_model \
+  --num_threads=4 \
+  --graph=/data/local/tmp/tflite_models/${GRAPH} \
+  --warmup_runs=1 \
+  --num_runs=50
+```
+
+To run with nnapi delegate, set `--use_nnapi=true`. To run with GPU delegate,
+set `--use_gpu=true`.
+
+The performance values below are measured on Android 10.
+
+<table>
+  <thead>
+    <tr>
+      <th>Model Name</th>
+      <th>Device </th>
+      <th>CPU, 4 threads</th>
+      <th>GPU</th>
+      <th>NNAPI</th>
+    </tr>
+  </thead>
+  <tr>
+    <td rowspan = 2>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+    </td>
+    <td>Pixel 3 </td>
+    <td>23.9 ms</td>
+    <td>6.45 ms</td>
+    <td>13.8 ms</td>
+  </tr>
+   <tr>
+     <td>Pixel 4 </td>
+    <td>14.0 ms</td>
+    <td>9.0 ms</td>
+    <td>14.8 ms</td>
+  </tr>
+  <tr>
+    <td rowspan = 2>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
+    </td>
+    <td>Pixel 3 </td>
+    <td>13.4 ms</td>
+    <td>--- </td>
+    <td>6.0 ms</td>
+  </tr>
+   <tr>
+     <td>Pixel 4 </td>
+    <td>5.0 ms</td>
+    <td>--- </td>
+    <td>3.2 ms</td>
+  </tr>
+  <tr>
+    <td rowspan = 2>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+    </td>
+    <td>Pixel 3 </td>
+    <td>56 ms</td>
+    <td>--- </td>
+    <td>102 ms</td>
+  </tr>
+   <tr>
+     <td>Pixel 4 </td>
+    <td>34.5 ms</td>
+    <td>--- </td>
+    <td>99.0 ms</td>
+  </tr>
+  <tr>
+    <td rowspan = 2>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+    </td>
+    <td>Pixel 3 </td>
+    <td>35.8 ms</td>
+    <td>9.5 ms </td>
+    <td>18.5 ms</td>
+  </tr>
+   <tr>
+     <td>Pixel 4 </td>
+    <td>23.9 ms</td>
+    <td>11.1 ms</td>
+    <td>19.0 ms</td>
+  </tr>
+  <tr>
+    <td rowspan = 2>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+    </td>
+    <td>Pixel 3 </td>
+    <td>422 ms</td>
+    <td>99.8 ms </td>
+    <td>201 ms</td>
+  </tr>
+   <tr>
+     <td>Pixel 4 </td>
+    <td>272.6 ms</td>
+    <td>87.2 ms</td>
+    <td>171.1 ms</td>
+  </tr>
+  <tr>
+    <td rowspan = 2>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+    </td>
+    <td>Pixel 3 </td>
+    <td>486 ms</td>
+    <td>93 ms </td>
+    <td>292 ms</td>
+  </tr>
+   <tr>
+     <td>Pixel 4 </td>
+    <td>324.1 ms</td>
+    <td>97.6 ms</td>
+    <td>186.9 ms</td>
+  </tr>
+
+ </table>
+
+### iOS performance benchmarks
+
+These performance benchmark numbers were generated with the
+[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/ios).
+
+To run iOS benchmarks, the benchmark app was modified to include the appropriate
+model and `benchmark_params.json` was modified to set `num_threads` to 2. To use
+the GPU delegate, `"use_gpu" : "1"` and `"gpu_wait_type" : "aggressive"` options
+were also added to `benchmark_params.json`.
+
+<table>
+  <thead>
+    <tr>
+      <th>Model Name</th>
+      <th>Device </th>
+      <th>CPU, 2 threads</th>
+      <th>GPU</th>
+    </tr>
+  </thead>
+  <tr>
+    <td>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+    </td>
+    <td>iPhone XS </td>
+    <td>14.8 ms</td>
+    <td>3.4 ms</td>
+  </tr>
+  <tr>
+    <td>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
+    </td>
+    <td>iPhone XS </td>
+    <td>11 ms</td>
+    <td>---</td>
+  </tr>
+  <tr>
+    <td>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+    </td>
+    <td>iPhone XS </td>
+    <td>30.4 ms</td>
+    <td>---</td>
+  </tr>
+  <tr>
+    <td>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+    </td>
+    <td>iPhone XS </td>
+    <td>21.1 ms</td>
+    <td>15.5 ms</td>
+  </tr>
+  <tr>
+    <td>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+    </td>
+    <td>iPhone XS </td>
+    <td>261.1 ms</td>
+    <td>45.7 ms</td>
+  </tr>
+  <tr>
+    <td>
+      <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+    </td>
+    <td>iPhone XS </td>
+    <td>309 ms</td>
+    <td>54.4 ms</td>
+  </tr>
+ </table>
+
+## Trace TensorFlow Lite internals in Android
+
+Note: This feature is experimental and available only when the Android app is
+built with the nightly released TensorFlow Lite library. Stable libraries up to
+v2.3 do not support this.
+
+Internal events from the TensorFlow Lite interpreter of an Android app can be
+captured by
+[Android tracing tools](https://developer.android.com/topic/performance/tracing).
+The events use the same Android
+[Trace](https://developer.android.com/reference/android/os/Trace) API, so the
+captured events from Java/Kotlin code are seen together with TensorFlow Lite
+internal events.
+
+Some examples of events are:
+
+*   Operator invocation
+*   Graph modification by delegate
+*   Tensor allocation
+
+Among different options for capturing traces, this guide covers the Android
+Studio CPU Profiler and the System Tracing app. Refer to
+[Perfetto command-line tool](https://developer.android.com/studio/command-line/perfetto)
+or
+[Systrace command-line tool](https://developer.android.com/topic/performance/tracing/command-line)
+for other options.
+
+### Adding trace events in Java code
+
+This is a code snippet from the
+[Image Classification](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android)
+example app. The TensorFlow Lite interpreter runs in the
+`recognizeImage/runInference` section. This step is optional, but it is useful
+to help identify where the inference call is made.
+
+```java
+  Trace.beginSection("recognizeImage");
+  ...
+  // Runs the inference call.
+  Trace.beginSection("runInference");
+  tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
+  Trace.endSection();
+  ...
+  Trace.endSection();
+
+```
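+
+Because an unmatched `beginSection`/`endSection` pair corrupts the trace, you
+may want to close sections in `finally` blocks. A hardened variant of the
+snippet above (a sketch, not from the example app):
+
+```java
+Trace.beginSection("recognizeImage");
+try {
+  // Runs the inference call inside its own trace section.
+  Trace.beginSection("runInference");
+  try {
+    tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
+  } finally {
+    Trace.endSection();
+  }
+} finally {
+  Trace.endSection();
+}
+```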
+
+### Enable TensorFlow Lite tracing
+
+To enable TensorFlow Lite tracing, set the Android system property
+`debug.tflite.trace` to 1 before starting the Android app.
+
+```shell
+adb shell setprop debug.tflite.trace 1
+```
+
+If this property has been set when the TensorFlow Lite interpreter is
+initialized, key events (e.g., operator invocation) from the interpreter will
+be traced.
+
+After you have captured all the traces, disable tracing by setting the property
+value to 0.
+
+```shell
+adb shell setprop debug.tflite.trace 0
+```
+
+### Android Studio CPU Profiler
+
+Capture traces with the
+[Android Studio CPU Profiler](https://developer.android.com/studio/profile/cpu-profiler)
+by following the steps below:
+
+1.  Select **Run > Profile 'app'** from the top menus.
+
+2.  Click anywhere in the CPU timeline when the Profiler window appears.
+
+3.  Select 'Trace System Calls' among the CPU profiling modes.
+
+    ![Select 'Trace System Calls'](images/as_select_profiling_mode.png)
+
+4.  Press the 'Record' button.
+
+5.  Press the 'Stop' button.
+
+6.  Investigate the trace result.
+
+    ![Android Studio trace](images/as_traces.png)
+
+In this example, you can see the hierarchy of events in a thread, statistics
+for each operator's execution time, and the data flow of the whole app among
+threads.
+
+### System Tracing app
+
+Capture traces without Android Studio by following the steps detailed in
+[System Tracing app](https://developer.android.com/topic/performance/tracing/on-device).
+
+In this example, the same TFLite events were captured and saved to the Perfetto
+or Systrace format depending on the Android version of the device. The captured
+trace files can be opened in the [Perfetto UI](https://ui.perfetto.dev/#!/).
+
+![Perfetto trace](images/perfetto_traces.png)
+
+### Using the tracing data
+
+The tracing data allows you to identify performance bottlenecks.
+
+Here are some examples of insights that you can get from the profiler and
+potential solutions to improve performance:
+
+*   If the number of available CPU cores is smaller than the number of
+    inference threads, then CPU scheduling overhead can lead to subpar
+    performance. You can reschedule other CPU-intensive tasks in your
+    application to avoid overlapping with your model inference, or tune the
+    number of interpreter threads (see the sketch below).
+*   If the operators are not fully delegated, then some parts of the model graph
+    are executed on the CPU rather than the expected hardware accelerator. You
+    can substitute the unsupported operators with similar supported operators.
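+
+For the first case, a minimal sketch of capping the interpreter thread count
+(the value 2 and the `modelBuffer` variable are assumptions; tune per device):
+
+```java
+// Limit TFLite to two CPU threads to reduce scheduling overhead.
+Interpreter.Options options = new Interpreter.Options().setNumThreads(2);
+Interpreter interpreter = new Interpreter(modelBuffer, options);
+```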
diff --git a/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
index a2835f5..ab3a2ef 100644
--- a/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
+++ b/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
@@ -679,12 +679,6 @@
         "    interpreter.invoke()\n",
         "    output = interpreter.get_tensor(output_details[\"index\"])[0]\n",
         "\n",
-        "    # Check if the output type is quantized, then rescale output data to float\n",
-        "    if output_details['dtype'] == np.uint8:\n",
-        "      output_scale, output_zero_point = output_details[\"quantization\"]\n",
-        "      test_image = test_image.astype(np.float32)\n",
-        "      test_image = test_image / input_scale + input_zero_point\n",
-        "\n",
         "    predictions[i] = output.argmax()\n",
         "\n",
         "  return predictions\n"
diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md
index 6198798..5bfe60e 100644
--- a/tensorflow/lite/g3doc/performance/post_training_quantization.md
+++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md
@@ -89,6 +89,9 @@
 [TensorFlow Lite for Microcontrollers](https://www.tensorflow.org/lite/microcontrollers)
 and [Coral Edge TPUs](https://coral.ai/).*
 
+Note: Starting from TensorFlow 2.3.0, we support the `inference_input_type` and
+`inference_output_type` attributes.
+
 Additionally, to ensure compatibility with integer only devices (such as 8-bit
 microcontrollers) and accelerators (such as the Coral Edge TPU), you can enforce
 full integer quantization for all ops including the input and output, by using
diff --git a/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md b/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md
index 4c001bc..6cb2018 100644
--- a/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md
+++ b/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md
@@ -2,175 +2,165 @@
 
 This page shows how to use the TensorFlow Lite Converter in the command line.
 
+_Note: If possible, use the **recommended** [Python API](python_api.md)
+instead._
+
 ## Command-line tools <a name="tools"></a>
 
+### Starting from TensorFlow 1.9
+
 There are two approaches to running the converter in the command line.
 
-*   `tflite_convert`: Starting from TensorFlow 1.9, the command-line tool
-    `tflite_convert` is installed as part of the Python package. All of the
-    examples below use `tflite_convert` for simplicity.
-    *   Example: `tflite_convert --output_file=...`
-*   `bazel`: In order to run the latest version of the TensorFlow Lite Converter
-    either install the nightly build using
-    [pip](https://www.tensorflow.org/install/pip) or
-    [clone the TensorFlow repository](https://www.tensorflow.org/install/source)
-    and use `bazel`.
-    *   Example: `bazel run
-        //third_party/tensorflow/lite/python:tflite_convert --
+*   `tflite_convert` (**recommended**):
+    *   *Install*: TensorFlow using
+        [pip](https://www.tensorflow.org/install/pip).
+    *   *Example*: `tflite_convert --output_file=...`
+*   `bazel`:
+    *   *Install*: TensorFlow from
+        [source](https://www.tensorflow.org/install/source).
+    *   *Example*: `bazel run tensorflow/lite/python:tflite_convert --
         --output_file=...`
 
-### Converting models prior to TensorFlow 1.9 <a name="pre_tensorflow_1.9"></a>
+*All of the following examples use `tflite_convert` for simplicity.
+Alternatively, you can replace '`tflite_convert`' with '`bazel run
+tensorflow/lite/python:tflite_convert --`'*
+
+### Prior to TensorFlow 1.9 <a name="pre_tensorflow_1.9"></a>
 
 The recommended approach for using the converter prior to TensorFlow 1.9 is the
-[Python API](python_api.md#pre_tensorflow_1.9). If a command line tool is
-desired, the `toco` command line tool was available in TensorFlow 1.7. Enter
-`toco --help` in Terminal for additional details on the command-line flags
-available. There were no command line tools in TensorFlow 1.8.
+[Python API](python_api.md). Only in TensorFlow 1.7, a command line tool `toco`
+was available (run `toco --help` for additional details).
 
-## Basic examples <a name="basic"></a>
+## Usage <a name="usage"></a>
 
-The following section shows examples of how to convert a basic float-point model
-from each of the supported data formats into a TensorFlow Lite FlatBuffers.
+### Setup <a name="download_models"></a>
 
-### Convert a TensorFlow GraphDef <a name="graphdef"></a>
-
-The follow example converts a basic TensorFlow GraphDef (frozen by
-[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py))
-into a TensorFlow Lite FlatBuffer to perform floating-point inference. Frozen
-graphs contain the variables stored in Checkpoint files as Const ops.
+Before we begin, download the models required to run the examples in this
+document:
 
 ```
+echo "Download MobileNet V1"
 curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
   | tar xzv -C /tmp
+
+echo "Download Inception V1"
+curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
+  | tar xzv -C /tmp
+```
+
+### Basic examples <a name="basic"></a>
+
+The following section shows examples of how to convert a basic model from each
+of the supported data formats into a TensorFlow Lite model.
+
+#### Convert a SavedModel <a name="savedmodel"></a>
+
+```
 tflite_convert \
-  --output_file=/tmp/foo.tflite \
+  --saved_model_dir=/tmp/saved_model \
+  --output_file=/tmp/foo.tflite
+```
+
+#### Convert a tf.keras model <a name="keras"></a>
+
+```
+tflite_convert \
+  --keras_model_file=/tmp/keras_model.h5 \
+  --output_file=/tmp/foo.tflite
+```
+
+#### Convert a Frozen GraphDef <a name="graphdef"></a>
+
+```
+tflite_convert \
   --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+  --output_file=/tmp/foo.tflite \
   --input_arrays=input \
   --output_arrays=MobilenetV1/Predictions/Reshape_1
 ```
 
-The value for `input_shapes` is automatically determined whenever possible.
+Frozen GraphDef models (or frozen graphs) are produced by
+[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)
+and require the additional flags `--input_arrays` and `--output_arrays`, as
+this information is not stored in the model format.
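+
+*If you don't know the input and output array names of a frozen graph, you can
+inspect it with [Netron](https://github.com/lutzroeder/netron) or the
+[summarize graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#inspecting-graphs).*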
 
-### Convert a TensorFlow SavedModel <a name="savedmodel"></a>
+### Advanced examples
 
-The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite
-FlatBuffer to perform floating-point inference.
+#### Convert a quantization aware trained model into a quantized TensorFlow Lite model
+
+If you have a quantization aware trained model (i.e., a model with `FakeQuant*`
+operations inserted to record the (min, max) ranges of tensors so that they can
+be quantized), then convert it into a quantized TensorFlow Lite model as shown
+below:
 
 ```
 tflite_convert \
+  --graph_def_file=/tmp/some_mobilenetv1_quantized_frozen_graph.pb \
   --output_file=/tmp/foo.tflite \
-  --saved_model_dir=/tmp/saved_model
-```
-
-[SavedModel](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators)
-has fewer required flags than frozen graphs due to access to additional data
-contained within the SavedModel. The values for `--input_arrays` and
-`--output_arrays` are an aggregated, alphabetized list of the inputs and outputs
-in the [SignatureDefs](../../serving/signature_defs.md) within
-the
-[MetaGraphDef](https://www.tensorflow.org/saved_model#apis_to_build_and_load_a_savedmodel)
-specified by `--saved_model_tag_set`. As with the GraphDef, the value for
-`input_shapes` is automatically determined whenever possible.
-
-There is currently no support for MetaGraphDefs without a SignatureDef or for
-MetaGraphDefs that use the [`assets/`
-directory](https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory).
-
-### Convert a tf.Keras model <a name="keras"></a>
-
-The following example converts a `tf.keras` model into a TensorFlow Lite
-Flatbuffer. The `tf.keras` file must contain both the model and the weights.
-
-```
-tflite_convert \
-  --output_file=/tmp/foo.tflite \
-  --keras_model_file=/tmp/keras_model.h5
-```
-
-## Quantization
-
-### Convert a TensorFlow GraphDef for quantized inference <a name="graphdef_quant"></a>
-
-The TensorFlow Lite Converter is compatible with fixed point quantization models
-described
-[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/README.md).
-These are float models with `FakeQuant*` ops inserted at the boundaries of fused
-layers to record min-max range information. This generates a quantized inference
-workload that reproduces the quantization behavior that was used during
-training.
-
-The following command generates a quantized TensorFlow Lite FlatBuffer from a
-"quantized" TensorFlow GraphDef.
-
-```
-tflite_convert \
-  --output_file=/tmp/foo.tflite \
-  --graph_def_file=/tmp/some_quantized_graph.pb \
-  --inference_type=QUANTIZED_UINT8 \
   --input_arrays=input \
   --output_arrays=MobilenetV1/Predictions/Reshape_1 \
-  --mean_values=128 \
-  --std_dev_values=127
+  --inference_type=INT8 \
+  --mean_values=-0.5 \
+  --std_dev_values=127.7
 ```
 
-### Use \"dummy-quantization\" to try out quantized inference on a float graph <a name="dummy_quant"></a>
+*If you're setting `--inference_type=QUANTIZED_UINT8`, then use
+`--mean_values=128` and `--std_dev_values=127` instead.*
 
-In order to evaluate the possible benefit of generating a quantized graph, the
-converter allows "dummy-quantization" on float graphs. The flags
-`--default_ranges_min` and `--default_ranges_max` accept plausible values for
-the min-max ranges of the values in all arrays that do not have min-max
-information. "Dummy-quantization" will produce lower accuracy but will emulate
-the performance of a correctly quantized model.
+#### Convert a model with "dummy-quantization" into a quantized TensorFlow Lite model
+
+If you have a regular float model and only want to estimate the benefit of a
+quantized model (i.e., estimate the performance of the model as if it were
+quantization aware trained), then perform "dummy-quantization" using the flags
+`--default_ranges_min` and `--default_ranges_max`. When specified, they will be
+used as the default (min, max) range for all tensors that lack (min, max) range
+information. This allows quantization to proceed and lets you emulate the
+performance of a quantized TensorFlow Lite model, but with lower accuracy.
 
 The example below contains a model using Relu6 activation functions. Therefore,
 a reasonable guess is that most activation ranges should be contained in [0, 6].
 
 ```
-curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
-  | tar xzv -C /tmp
 tflite_convert \
-  --output_file=/tmp/foo.cc \
   --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
-  --inference_type=QUANTIZED_UINT8 \
+  --output_file=/tmp/foo.tflite \
   --input_arrays=input \
   --output_arrays=MobilenetV1/Predictions/Reshape_1 \
+  --inference_type=INT8 \
+  --mean_values=-0.5 \
+  --std_dev_values=127.7 \
   --default_ranges_min=0 \
-  --default_ranges_max=6 \
-  --mean_values=128 \
-  --std_dev_values=127
+  --default_ranges_max=6
 ```
 
-## Specifying input and output arrays
+*If you're setting `--inference_type=QUANTIZED_UINT8`, then use
+`--mean_values=128` and `--std_dev_values=127` instead.*
 
-### Multiple input arrays
+#### Convert a model with multiple input arrays
 
-The flag `input_arrays` takes in a comma-separated list of input arrays as seen
+The flag `--input_arrays` takes in a comma-separated list of input arrays as seen
 in the example below. This is useful for models or subgraphs with multiple
-inputs.
+inputs. Note that `--input_shapes` is provided as a colon-separated list. Each
+input shape corresponds to the input array at the same position in the
+respective list.
 
 ```
-curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
-  | tar xzv -C /tmp
 tflite_convert \
   --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \
   --output_file=/tmp/foo.tflite \
-  --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \
   --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \
+  --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \
   --output_arrays=InceptionV1/Logits/Predictions/Reshape_1
 ```
 
-Note that `input_shapes` is provided as a colon-separated list. Each input shape
-corresponds to the input array at the same position in the respective list.
+#### Convert a model with multiple output arrays
 
-### Multiple output arrays
-
-The flag `output_arrays` takes in a comma-separated list of output arrays as
+The flag `--output_arrays` takes in a comma-separated list of output arrays as
 seen in the example below. This is useful for models or subgraphs with multiple
 outputs.
 
 ```
-curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
-  | tar xzv -C /tmp
 tflite_convert \
   --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \
   --output_file=/tmp/foo.tflite \
@@ -178,50 +168,45 @@
   --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu
 ```
 
-### Specifying subgraphs
+#### Convert a model by specifying subgraphs
 
 Any array in the input file can be specified as an input or output array in
-order to extract subgraphs out of an input graph file. The TensorFlow Lite
-Converter discards the parts of the graph outside of the specific subgraph. Use
-[graph visualizations](#graph_visualizations) to identify the input and output
-arrays that make up the desired subgraph.
+order to extract subgraphs out of an input model file. The TensorFlow Lite
+Converter discards the parts of the model outside of the specified subgraph.
+Use [visualization](#visualization) to identify the input and output arrays
+that make up the desired subgraph.
 
-The follow command shows how to extract a single fused layer out of a TensorFlow
-GraphDef.
+The following command shows how to extract a single fused layer out of a
+TensorFlow GraphDef.
 
 ```
-curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
-  | tar xzv -C /tmp
 tflite_convert \
   --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \
   --output_file=/tmp/foo.pb \
-  --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \
   --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \
+  --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \
   --output_arrays=InceptionV1/InceptionV1/Mixed_3b/concat_v2
 ```
 
-Note that the final representation in TensorFlow Lite FlatBuffers tends to have
+Note that the final representation in TensorFlow Lite models tends to have
 coarser granularity than the very fine granularity of the TensorFlow GraphDef
 representation. For example, while a fully-connected layer is typically
-represented as at least four separate ops in TensorFlow GraphDef (Reshape,
-MatMul, BiasAdd, Relu...), it is typically represented as a single "fused" op
-(FullyConnected) in the converter's optimized representation and in the final
-on-device representation. As the level of granularity gets coarser, some
-intermediate arrays (say, the array between the MatMul and the BiasAdd in the
-TensorFlow GraphDef) are dropped.
+represented as at least four separate operations in TensorFlow GraphDef
+(Reshape, MatMul, BiasAdd, Relu...), it is typically represented as a single
+"fused" op (FullyConnected) in the converter's optimized representation and in
+the final on-device representation. As the level of granularity gets coarser,
+some intermediate arrays (say, the array between the MatMul and the BiasAdd in
+the TensorFlow GraphDef) are dropped.
 
 When specifying intermediate arrays as `--input_arrays` and `--output_arrays`,
 it is desirable (and often required) to specify arrays that are meant to survive
-in the final form of the graph, after fusing. These are typically the outputs of
+in the final form of the model, after fusing. These are typically the outputs of
 activation functions (since everything in each layer until the activation
 function tends to get fused).
 
-## Logging
+## Visualization <a name="visualization"></a>
 
-
-## Graph visualizations
-
-The converter can export a graph to the Graphviz Dot format for easy
+The converter can export a model to the Graphviz Dot format for easy
 visualization using either the `--output_format` flag or the
 `--dump_graphviz_dir` flag. The subsections below outline the use cases for
 each.
@@ -229,21 +214,20 @@
 ### Using `--output_format=GRAPHVIZ_DOT` <a name="using_output_format_graphviz_dot"></a>
 
 The first way to get a Graphviz rendering is to pass `GRAPHVIZ_DOT` into
-`--output_format`. This results in a plausible visualization of the graph. This
+`--output_format`. This results in a plausible visualization of the model. This
 reduces the requirements that exist during conversion from a TensorFlow GraphDef
-to a TensorFlow Lite FlatBuffer. This may be useful if the conversion to TFLite
-is failing.
+to a TensorFlow Lite model. This may be useful if the conversion to TFLite is
+failing.
 
 ```
-curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
-  | tar xzv -C /tmp
 tflite_convert \
   --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
   --output_file=/tmp/foo.dot \
   --output_format=GRAPHVIZ_DOT \
-  --input_shape=1,128,128,3 \
   --input_arrays=input \
+  --input_shape=1,128,128,3 \
   --output_arrays=MobilenetV1/Predictions/Reshape_1
 ```
 
 The resulting `.dot` file can be rendered into a PDF as follows:
@@ -267,12 +251,10 @@
 The second way to get a Graphviz rendering is to pass the `--dump_graphviz_dir`
 flag, specifying a destination directory to dump Graphviz rendering to. Unlike
 the previous approach, this one retains the original output format. This
-provides a visualization of the actual graph resulting from a specific
+provides a visualization of the actual model resulting from a specific
 conversion process.
 
 ```
-curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
-  | tar xzv -C /tmp
 tflite_convert \
   --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
   --output_file=/tmp/foo.tflite \
@@ -283,14 +265,14 @@
 
 This generates a few files in the destination directory. The two most important
-files are `toco_AT_IMPORT.dot` and `/tmp/toco_AFTER_TRANSFORMATIONS.dot`.
+files are `toco_AT_IMPORT.dot` and `toco_AFTER_TRANSFORMATIONS.dot`.
-`toco_AT_IMPORT.dot` represents the original graph containing only the
+`toco_AT_IMPORT.dot` represents the original model containing only the
 transformations done at import time. This tends to be a complex visualization
 with limited information about each node. It is useful in situations where a
 conversion command fails.
 
-`toco_AFTER_TRANSFORMATIONS.dot` represents the graph after all transformations
+`toco_AFTER_TRANSFORMATIONS.dot` represents the model after all transformations
 were applied to it, just before it is exported. Typically, this is a much
-smaller graph with more information about each node.
+smaller model with more information about each node.
 
 As before, these can be rendered to PDFs:
 
@@ -316,15 +298,15 @@
 <tr><td>before</td><td>after</td></tr>
 </table>
 
-### Graph "video" logging
+### Video logging
 
 When `--dump_graphviz_dir` is used, one may additionally pass
-`--dump_graphviz_video`. This causes a graph visualization to be dumped after
-each individual graph transformation, resulting in thousands of files.
+`--dump_graphviz_video`. This causes a model visualization to be dumped after
+each individual model transformation, resulting in thousands of files.
 Typically, one would then bisect into these files to understand when a given
-change was introduced in the graph.
+change was introduced in the model.
 
-### Legend for the graph visualizations <a name="graphviz_legend"></a>
+### Legend for the visualizations <a name="graphviz_legend"></a>
 
 *   Operators are red square boxes with the following hues of red:
     *   Most operators are
diff --git a/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md b/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md
index 8cca69d..826bb7a 100644
--- a/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md
+++ b/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md
@@ -1,42 +1,41 @@
 # Converter command line reference
 
-This page is complete reference of command-line flags used by the TensorFlow
-Lite Converter's command line starting from TensorFlow 1.9 up until the most
-recent build of TensorFlow.
+This page is a complete reference of the command-line flags used by the
+TensorFlow Lite Converter's command line tool.
 
 ## High-level flags
 
 The following high level flags specify the details of the input and output
 files. The flag `--output_file` is always required. Additionally, either
-`--graph_def_file`, `--saved_model_dir` or `--keras_model_file` is required.
+`--saved_model_dir`, `--keras_model_file` or `--graph_def_file` is required.
 
 *   `--output_file`. Type: string. Specifies the full path of the output file.
-*   `--graph_def_file`. Type: string. Specifies the full path of the input
-    GraphDef file frozen using
-    [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py).
 *   `--saved_model_dir`. Type: string. Specifies the full path to the directory
     containing the SavedModel.
 *   `--keras_model_file`. Type: string. Specifies the full path of the HDF5 file
     containing the tf.keras model.
+*   `--graph_def_file`. Type: string. Specifies the full path of the input
+    GraphDef file frozen using
+    [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py).
 *   `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of
     the output file. Allowed values:
-    *   `TFLITE`: TensorFlow Lite FlatBuffer format.
+    *   `TFLITE`: TensorFlow Lite model format.
     *   `GRAPHVIZ_DOT`: GraphViz `.dot` format containing a visualization of the
         graph after graph transformations.
         *   Note that passing `GRAPHVIZ_DOT` to `--output_format` leads to loss
-            of TFLite specific transformations. Therefore, the resulting
-            visualization may not reflect the final set of graph
-            transformations. To get a final visualization with all graph
-            transformations use `--dump_graphviz_dir` instead.
+            of TFLite specific transformations. To get a final visualization
+            with all graph transformations use `--dump_graphviz_dir` instead.
 
 The following flags specify optional parameters when using SavedModels.
 
-*   `--saved_model_tag_set`. Type: string. Default:
-    [kSavedModelTagServe](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h).
+*   `--saved_model_tag_set`. Type: string. Default: "serve" (for more options,
+    refer to
+    [tag_constants.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h)).
     Specifies a comma-separated set of tags identifying the MetaGraphDef within
     the SavedModel to analyze. All tags in the tag set must be specified.
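+    *   Example: `--saved_model_tag_set=serve`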
-*   `--saved_model_signature_key`. Type: string. Default:
-    `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.
+*   `--saved_model_signature_key`. Type: string. Default: "serving_default" (for
+    more options, refer to
+    [tf.compat.v1.saved_model.signature_constants](https://www.tensorflow.org/api_docs/python/tf/compat/v1/saved_model/signature_constants)).
     Specifies the key identifying the SignatureDef containing inputs and
     outputs.
 
@@ -46,9 +45,9 @@
 file.
 
 *   `--input_arrays`. Type: comma-separated list of strings. Specifies the list
-    of names of input activation tensors.
+    of names of input tensors.
 *   `--output_arrays`. Type: comma-separated list of strings. Specifies the list
-    of names of output activation tensors.
+    of names of output tensors.
 
 The following flags define properties of the input tensors. Each item in the
 `--input_arrays` flag should correspond to each item in the following flags
@@ -56,8 +55,7 @@
 
 *   `--input_shapes`. Type: colon-separated list of comma-separated lists of
     integers. Each comma-separated list of integers gives the shape of one of
-    the input arrays specified in
-    [TensorFlow convention](https://www.tensorflow.org/guide/tensors#shape).
+    the input arrays.
     *   Example: `--input_shapes=1,60,80,3` for a typical vision model means a
         batch size of 1, an input image height of 60, an input image width of
         80, and an input image depth of 3 (representing RGB channels).
@@ -65,24 +63,24 @@
         has a shape of [2, 3] and "bar" has a shape of [4, 5, 6].
 *   `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats.
     These specify the (de-)quantization parameters of the input array, when it
-    is quantized. This is only needed if `inference_input_type` is
+    is quantized. This is only needed if `inference_input_type` is `INT8` or
     `QUANTIZED_UINT8`.
     *   The meaning of `mean_values` and `std_dev_values` is as follows: each
         quantized value in the quantized input array will be interpreted as a
         mathematical real number (i.e. as an input activation value) according
         to the following formula:
-        *   `real_value = (quantized_input_value - mean_value) / std_dev_value`.
+        *   `real_value = (quantized_value - mean_value) / std_dev_value`.
     *   When performing float inference (`--inference_type=FLOAT`) on a
         quantized input, the quantized input would be immediately dequantized by
         the inference code according to the above formula, before proceeding
         with float inference.
-    *   When performing quantized inference
-        (`--inference_type=QUANTIZED_UINT8`), no dequantization is performed by
-        the inference code. However, the quantization parameters of all arrays,
-        including those of the input arrays as specified by `mean_value` and
-        `std_dev_value`, determine the fixed-point multipliers used in the
-        quantized inference code. `mean_value` must be an integer when
-        performing quantized inference.
+    *   When performing quantized inference (`inference_type` is `INT8` or
+        `QUANTIZED_UINT8`), no dequantization is performed by the inference
+        code. However, the quantization parameters of all arrays, including
+        those of the input arrays as specified by `mean_value` and
+        `std_dev_value`, determine the fixed-point multipliers used in the
+        quantized inference code. `mean_value` must be an integer when
+        performing quantized inference.
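+    *   For example, with `mean_values=128` and `std_dev_values=127`, a
+        quantized input value of 255 is interpreted as the real value
+        `(255 - 128) / 127 = 1.0`, and a quantized value of 0 as
+        `(0 - 128) / 127 ≈ -1.0`, giving an effective input range of roughly
+        [-1, 1].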
 
 ## Transformation flags
 
@@ -92,7 +90,7 @@
 
 *   `--inference_type`. Type: string. Default: `FLOAT`. Data type of all
     real-number arrays in the output file except for input arrays (defined by
-    `--inference_input_type`). Must be `{FLOAT, QUANTIZED_UINT8}`.
+    `--inference_input_type`). Must be `{FLOAT, INT8, QUANTIZED_UINT8}`.
 
     This flag only impacts real-number arrays including float and quantized
     arrays. This excludes all other data types including plain integer arrays
@@ -101,6 +99,9 @@
     *   If `FLOAT`, then real-numbers arrays will be of type float in the output
         file. If they were quantized in the input file, then they get
         dequantized.
+    *   If `INT8`, then real-numbers arrays will be quantized as int8 in the
+        output file. If they were float in the input file, then they get
+        quantized.
     *   If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as
         uint8 in the output file. If they were float in the input file, then
         they get quantized.
@@ -109,7 +110,8 @@
     array in the output file. By default the `--inference_type` is used as type
     of all of the input arrays. Flag is primarily intended for generating a
-    float-point graph with a quantized input array. A Dequantized operator is
+    floating-point graph with a quantized input array. A Dequantize operator is
-    added immediately after the input array. Must be `{FLOAT, QUANTIZED_UINT8}`.
+    added immediately after the input array. Must be `{FLOAT, INT8,
+    QUANTIZED_UINT8}`.
 
     The flag is typically used for vision models taking a bitmap as input but
     requiring floating-point inference. For such image models, the uint8 input
diff --git a/tensorflow/lite/g3doc/r1/convert/index.md b/tensorflow/lite/g3doc/r1/convert/index.md
index 4080689..7a4e8c7 100644
--- a/tensorflow/lite/g3doc/r1/convert/index.md
+++ b/tensorflow/lite/g3doc/r1/convert/index.md
@@ -1,48 +1,48 @@
 # TensorFlow Lite converter
 
-The TensorFlow Lite converter is used to convert TensorFlow models into an
-optimized [FlatBuffer](https://google.github.io/flatbuffers/) format, so that
-they can be used by the TensorFlow Lite interpreter.
+The TensorFlow Lite converter takes a TensorFlow model and generates a
+TensorFlow Lite model, which is an optimized
+[FlatBuffer](https://google.github.io/flatbuffers/) (identified by the `.tflite`
+file extension).
 
 Note: This page contains documentation on the converter API for TensorFlow 1.x.
 The API for TensorFlow 2.0 is available
 [here](https://www.tensorflow.org/lite/convert/).
 
-## FlatBuffers
+## Options
+
+The TensorFlow Lite Converter can be used in two ways:
+
+*   [Python API](python_api.md) (**recommended**): Using the Python API makes it
+    easier to convert models as part of a model development pipeline and helps
+    mitigate compatibility issues early on.
+*   [Command line](cmdline_examples.md)
+
+## Workflow
+
+### Why use the 'FlatBuffer' format?
 
 FlatBuffer is an efficient open-source cross-platform serialization library. It
-is similar to
-[protocol buffers](https://developers.google.com/protocol-buffers), with the
-distinction that FlatBuffers do not need a parsing/unpacking step to a secondary
-representation before data can be accessed, avoiding per-object memory
-allocation. The code footprint of FlatBuffers is an order of magnitude smaller
-than protocol buffers.
+is similar to [protocol buffers](https://developers.google.com/protocol-buffers)
+used in the TensorFlow model format, with the distinction that FlatBuffers do
+not need a parsing/unpacking step to a secondary representation before data can
+be accessed, avoiding per-object memory allocation. The code footprint of
+FlatBuffers is an order of magnitude smaller than protocol buffers.
 
-## From model training to device deployment
-
-The TensorFlow Lite converter generates a TensorFlow Lite
-[FlatBuffer](https://google.github.io/flatbuffers/) file (`.tflite`) from a
-TensorFlow model.
+### Convert the model
 
 The converter supports the following input formats:
 
 *   [SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators)
-*   Frozen `GraphDef`: Models generated by
+*   `tf.keras` H5 models.
+*   Frozen `GraphDef` models generated using
     [freeze_graph.py](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py).
-*   `tf.keras` HDF5 models.
-*   Any model taken from a `tf.Session` (Python API only).
+*   `tf.Session` models (Python API only).
 
-The TensorFlow Lite `FlatBuffer` file is then deployed to a client device, and
-the TensorFlow Lite interpreter uses the compressed model for on-device
-inference. This conversion process is shown in the diagram below:
+### Run inference
+
+The TensorFlow Lite model is then deployed to a client device, and the
+TensorFlow Lite interpreter uses the compressed model for on-device inference.
+This conversion process is shown in the diagram below:
 
 ![TFLite converter workflow](../images/convert/workflow.svg)
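+
+As a minimal sketch of this step (assuming a converted model saved as
+`model.tflite`, with a single input and a single output), the model can be run
+with the TensorFlow Lite Python interpreter:
+
+```python
+import numpy as np
+import tensorflow as tf
+
+# Load the TensorFlow Lite model and allocate its tensors.
+interpreter = tf.lite.Interpreter(model_path='model.tflite')
+interpreter.allocate_tensors()
+
+# Generate random input data matching the model's input shape and type.
+input_details = interpreter.get_input_details()[0]
+input_data = np.array(np.random.random_sample(input_details['shape']),
+                      dtype=input_details['dtype'])
+
+# Run inference and read the output tensor.
+interpreter.set_tensor(input_details['index'], input_data)
+interpreter.invoke()
+output_details = interpreter.get_output_details()[0]
+output_data = interpreter.get_tensor(output_details['index'])
+```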
-
-## Options
-
-The TensorFlow Lite Converter can be used from either of these two options:
-
-*   [Python](python_api.md) (**Preferred**): Using the Python API makes it
-    easier to convert models as part of a model development pipeline, and helps
-    mitigate [compatibility](../tf_ops_compatibility.md) issues early on.
-*   [Command line](cmdline_examples.md)
diff --git a/tensorflow/lite/g3doc/r1/convert/python_api.md b/tensorflow/lite/g3doc/r1/convert/python_api.md
index 30d6575..0eca3a4 100644
--- a/tensorflow/lite/g3doc/r1/convert/python_api.md
+++ b/tensorflow/lite/g3doc/r1/convert/python_api.md
@@ -1,119 +1,67 @@
 # Converter Python API guide
 
 This page describes how to convert TensorFlow models into the TensorFlow Lite
-format using the TensorFlow Lite Converter Python API.
+format using the
+[`tf.compat.v1.lite.TFLiteConverter`](https://www.tensorflow.org/api_docs/python/tf/compat/v1/lite/TFLiteConverter)
+Python API. It provides the following class methods based on the original format
+of the model:
 
-If you're looking for information about how to run a TensorFlow Lite model,
-see [TensorFlow Lite inference](../guide/inference.md).
+*   `tf.compat.v1.lite.TFLiteConverter.from_saved_model()`: Converts a
+    [SavedModel](https://www.tensorflow.org/guide/saved_model).
+*   `tf.compat.v1.lite.TFLiteConverter.from_keras_model_file()`: Converts a
+    [Keras](https://www.tensorflow.org/guide/keras/overview) model file.
+*   `tf.compat.v1.lite.TFLiteConverter.from_session()`: Converts a GraphDef from
+    a session.
+*   `tf.compat.v1.lite.TFLiteConverter.from_frozen_graph()`: Converts a Frozen
+    GraphDef from a file. If you have checkpoints, then first convert them to a
+    Frozen GraphDef file and then use this API as shown [here](#checkpoints).
 
-Note: This page describes the converter in the TensorFlow nightly release,
-installed using `pip install tf-nightly`. For docs describing older versions
-reference ["Converting models from TensorFlow 1.12"](#pre_tensorflow_1.12).
-
-
-## High-level overview
-
-While the TensorFlow Lite Converter can be used from the command line, it is
-often convenient to use in a Python script as part of the model development
-pipeline. This allows you to know early that you are designing a model that can
-be targeted to devices with mobile.
-
-## API
-
-The API for converting TensorFlow models to TensorFlow Lite is
-`tf.lite.TFLiteConverter`, which provides class methods based on the original
-format of the model. For example, `TFLiteConverter.from_session()` is available
-for GraphDefs, `TFLiteConverter.from_saved_model()` is available for
-SavedModels, and `TFLiteConverter.from_keras_model_file()` is available for
-`tf.Keras` files.
-
-Example usages for simple float-point models are shown in
-[Basic Examples](#basic). Examples usages for more complex models is shown in
-[Complex Examples](#complex).
+In the following sections, we discuss [basic examples](#basic) and
+[complex examples](#complex).
 
 ## Basic examples <a name="basic"></a>
 
-The following section shows examples of how to convert a basic float-point model
-from each of the supported data formats into a TensorFlow Lite FlatBuffers.
+The following section shows examples of how to convert a basic model from each
+of the supported model formats into a TensorFlow Lite model.
 
-### Exporting a GraphDef from tf.Session <a name="basic_graphdef_sess"></a>
+### Convert a SavedModel <a name="basic_savedmodel"></a>
 
-The following example shows how to convert a TensorFlow GraphDef into a
-TensorFlow Lite FlatBuffer from a `tf.Session` object.
+The following example shows how to convert a
+[SavedModel](https://www.tensorflow.org/guide/saved_model) into a TensorFlow
+Lite model.
 
 ```python
 import tensorflow as tf
 
-img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
-val = img + var
-out = tf.identity(val, name="out")
-
-with tf.Session() as sess:
-  sess.run(tf.global_variables_initializer())
-  converter = tf.lite.TFLiteConverter.from_session(sess, [img], [out])
-  tflite_model = converter.convert()
-  open("converted_model.tflite", "wb").write(tflite_model)
-```
-
-### Exporting a GraphDef from file <a name="basic_graphdef_file"></a>
-
-The following example shows how to convert a TensorFlow GraphDef into a
-TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and
-`.pbtxt` files are accepted.
-
-The example uses
-[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz).
-The function only supports GraphDefs frozen using
-[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py).
-
-```python
-import tensorflow as tf
-
-graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb"
-input_arrays = ["input"]
-output_arrays = ["MobilenetV1/Predictions/Softmax"]
-
-converter = tf.lite.TFLiteConverter.from_frozen_graph(
-  graph_def_file, input_arrays, output_arrays)
+# Convert the model.
+converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(saved_model_dir)
 tflite_model = converter.convert()
-open("converted_model.tflite", "wb").write(tflite_model)
+
+# Save the model.
+with open('model.tflite', 'wb') as f:
+  f.write(tflite_model)
 ```
 
-### Exporting a SavedModel <a name="basic_savedmodel"></a>
+### Convert a Keras model file <a name="basic_keras_file"></a>
 
-The following example shows how to convert a SavedModel into a TensorFlow Lite
-FlatBuffer.
+The following example shows how to convert a
+[Keras](https://www.tensorflow.org/guide/keras/overview) model file into a
+TensorFlow Lite model.
 
 ```python
 import tensorflow as tf
 
-converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+# Convert the model.
+converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file('keras_model.h5')
 tflite_model = converter.convert()
-open("converted_model.tflite", "wb").write(tflite_model)
+
+# Save the model.
+with open('model.tflite', 'wb') as f:
+  f.write(tflite_model)
 ```
 
-For more complex SavedModels, the optional parameters that can be passed into
-`TFLiteConverter.from_saved_model()` are `input_arrays`, `input_shapes`,
-`output_arrays`, `tag_set` and `signature_key`. Details of each parameter are
-available by running `help(tf.lite.TFLiteConverter)`.
-
-### Exporting a tf.keras File <a name="basic_keras_file"></a>
-
-The following example shows how to convert a `tf.keras` model into a TensorFlow
-Lite FlatBuffer. This example requires
-[`h5py`](http://docs.h5py.org/en/latest/build.html) to be installed.
-
-```python
-import tensorflow as tf
-
-converter = tf.lite.TFLiteConverter.from_keras_model_file("keras_model.h5")
-tflite_model = converter.convert()
-open("converted_model.tflite", "wb").write(tflite_model)
-```
-
-The `tf.keras` file must contain both the model and the weights. A comprehensive
-example including model construction can be seen below.
+The Keras file contains both the model and the weights. A comprehensive example
+is given below.
 
 ```python
 import numpy as np
@@ -134,61 +82,139 @@
 model.train_on_batch(x, y)
 model.predict(x)
 
-# Save tf.keras model in HDF5 format.
-keras_file = "keras_model.h5"
+# Save tf.keras model in H5 format.
+keras_file = 'keras_model.h5'
 tf.keras.models.save_model(model, keras_file)
 
-# Convert to TensorFlow Lite model.
-converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
+# Convert the model.
+converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(keras_file)
 tflite_model = converter.convert()
-open("converted_model.tflite", "wb").write(tflite_model)
+
+# Save the model.
+with open('model.tflite', 'wb') as f:
+  f.write(tflite_model)
 ```
 
-## Complex examples <a name="complex"></a>
+### Convert a GraphDef from a session <a name="basic_graphdef_sess"></a>
 
-For models where the default value of the attributes is not sufficient, the
-attribute's values should be set before calling `convert()`. In order to call
-any constants use `tf.lite.constants.<CONSTANT_NAME>` as seen below with
-`QUANTIZED_UINT8`. Run `help(tf.lite.TFLiteConverter)` in the Python
-terminal for detailed documentation on the attributes.
-
-Although the examples are demonstrated on GraphDefs containing only constants.
-The same logic can be applied irrespective of the input data format.
-
-### Exporting a quantized GraphDef <a name="complex_quant"></a>
-
-The following example shows how to convert a quantized model into a TensorFlow
-Lite FlatBuffer.
+The following example shows how to convert a GraphDef from a `tf.Session` object
+into a TensorFlow Lite model.
 
 ```python
 import tensorflow as tf
 
-img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
-val = img + const
-out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output")
+img = tf.placeholder(name='img', dtype=tf.float32, shape=(1, 64, 64, 3))
+var = tf.get_variable('weights', dtype=tf.float32, shape=(1, 64, 64, 3))
+val = img + var
+out = tf.identity(val, name='out')
 
 with tf.Session() as sess:
-  converter = tf.lite.TFLiteConverter.from_session(sess, [img], [out])
-  converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
-  input_arrays = converter.get_input_arrays()
-  converter.quantized_input_stats = {input_arrays[0] : (0., 1.)}  # mean, std_dev
+  sess.run(tf.global_variables_initializer())
+
+  # Convert the model.
+  converter = tf.compat.v1.lite.TFLiteConverter.from_session(sess, [img], [out])
   tflite_model = converter.convert()
-  open("converted_model.tflite", "wb").write(tflite_model)
+
+  # Save the model.
+  with open('model.tflite', 'wb') as f:
+    f.write(tflite_model)
 ```
 
+### Convert a Frozen GraphDef from file <a name="basic_graphdef_file"></a>
 
-## Additional instructions
+The following example shows how to convert a Frozen GraphDef (or a frozen
+graph), usually generated using the
+[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)
+script, into a TensorFlow Lite model.
 
-### Build from source code <a name="latest_package"></a>
+The example uses
+[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz).
 
-In order to run the latest version of the TensorFlow Lite Converter Python API,
-either install the nightly build with
-[pip](https://www.tensorflow.org/install/pip) (recommended) or
-[Docker](https://www.tensorflow.org/install/docker), or
-[build the pip package from source](https://www.tensorflow.org/install/source).
+```python
+import tensorflow as tf
 
-### Converting models from TensorFlow 1.12 <a name="pre_tensorflow_1.12"></a>
+# Convert the model. Both `.pb` and `.pbtxt` files are accepted as the
+# `graph_def_file`.
+converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
+    graph_def_file='/path/to/mobilenet_v1_1.0_224/frozen_graph.pb',
+    input_arrays=['input'],
+    input_shapes={'input': [1, 224, 224, 3]},
+    output_arrays=['MobilenetV1/Predictions/Softmax']
+)
+tflite_model = converter.convert()
+
+# Save the model.
+with open('model.tflite', 'wb') as f:
+  f.write(tflite_model)
+```
+
+#### Convert checkpoints <a name="checkpoints"></a>
+
+1.  Convert checkpoints to a Frozen GraphDef as follows
+    (*[reference](https://laid.delanover.com/how-to-freeze-a-graph-in-tensorflow/)*):
+
+    *   Install [bazel](https://docs.bazel.build/versions/master/install.html)
+    *   Clone the TensorFlow repository: `git clone
+        https://github.com/tensorflow/tensorflow.git`
+    *   Build freeze graph tool: `bazel build
+        tensorflow/python/tools:freeze_graph`
+        *   The directory from which you run this should contain a file named
+            'WORKSPACE'.
+        *   If you're running on Ubuntu 16.04 and run into issues, update the
+            command to `bazel build -c opt --copt=-msse4.1 --copt=-msse4.2
+            tensorflow/python/tools:freeze_graph`
+    *   Run freeze graph tool: `bazel run tensorflow/python/tools:freeze_graph
+        --input_graph=/path/to/graph.pbtxt --input_binary=false
+        --input_checkpoint=/path/to/model.ckpt-00010
+        --output_graph=/path/to/frozen_graph.pb
+        --output_node_names=name1,name2.....`
+        *   If you have an input `*.pb` file instead of `*.pbtxt`, then replace
+            `--input_graph=/path/to/graph.pbtxt --input_binary=false` with
+            `--input_graph=/path/to/graph.pb`
+        *   You can find the output names by exploring the graph using
+            [Netron](https://github.com/lutzroeder/netron) or
+            [summarize graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#inspecting-graphs).
+
+2.  Now [convert the Frozen GraphDef file](#basic_graphdef_file) to a TensorFlow
+    Lite model as shown in the example above.
+
+## Complex examples <a name="complex"></a>
+
+For models where the default values of the attributes are not sufficient, the
+attribute values should be set before calling `convert()`. Run
+`help(tf.compat.v1.lite.TFLiteConverter)` in the Python terminal for detailed
+documentation on the attributes.
+
+### Convert a quantization aware trained model <a name="complex_quant"></a>
+
+The following example shows how to convert a quantization aware trained model
+into a TensorFlow Lite model.
+
+The example uses
+[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz).
+
+```python
+import tensorflow as tf
+
+# Convert the model.
+converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
+    graph_def_file='/path/to/mobilenet_v1_1.0_224/frozen_graph.pb',
+    input_arrays=['input'],
+    input_shapes={'input': [1, 224, 224, 3]},
+    output_arrays=['MobilenetV1/Predictions/Softmax'],
+)
+converter.quantized_input_stats = {'input': (0., 1.)}  # mean, std_dev (input range is [-1, 1])
+converter.inference_type = tf.int8  # this is the recommended type
+# converter.inference_input_type = tf.uint8  # optional
+# converter.inference_output_type = tf.uint8  # optional
+tflite_model = converter.convert()
+
+# Save the model.
+with open('quantized_model.tflite', 'wb') as f:
+  f.write(tflite_model)
+```
+
+## Convert models from TensorFlow 1.12 <a name="pre_tensorflow_1.12"></a>
 
 Reference the following table to convert TensorFlow models to TensorFlow Lite in
 and before TensorFlow 1.12. Run `help()` to get details of each API.
diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h
index f27a17d..4561e13 100644
--- a/tensorflow/lite/interpreter.h
+++ b/tensorflow/lite/interpreter.h
@@ -538,8 +538,8 @@
   // for the tensor, it can no longer be reset to the TFLite arena memory.
   //
   // Parameters should satisfy the following conditions:
-  // 1. tensor->allocation_type == kTfLiteArenaRw
-  //    In general, this is true for all non-constants such as I/O tensors.
+  // 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
+  //    In general, this is true for I/O tensors & variable tensors.
   // 2. allocation->data has the appropriate permissions for runtime access
   //    (Read-only for inputs, Read-Write for others), and outlives Interpreter.
   // 3. allocation->bytes >= tensor->bytes.
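+  //
+  // A usage sketch (illustrative only; `buffer` and `kBufferBytes` stand for
+  // suitably aligned, caller-owned storage):
+  //   TfLiteCustomAllocation allocation{buffer, kBufferBytes};
+  //   interpreter->SetCustomAllocationForTensor(tensor_index, allocation);
+  //   interpreter->AllocateTensors();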
diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc
index 0765f00..7743bc7 100644
--- a/tensorflow/lite/interpreter_builder.cc
+++ b/tensorflow/lite/interpreter_builder.cc
@@ -27,15 +27,12 @@
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/core/api/flatbuffer_conversions.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/profiling/platform_profiler.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/shared_library.h"
 #include "tensorflow/lite/util.h"
 #include "tensorflow/lite/version.h"
 
-#if defined(TFLITE_ENABLE_DEFAULT_PROFILER)
-#include "tensorflow/lite/profiling/platform_profiler.h"
-#endif
-
 // aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
 #if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
 #if !defined(__ANDROID__) || __ANDROID_API__ >= 28
@@ -630,9 +627,7 @@
     (*interpreter)->AddSubgraphs(subgraphs->size() - 1);
   }
 
-#if defined(TFLITE_ENABLE_DEFAULT_PROFILER)
-  (*interpreter)->SetProfiler(tflite::profiling::CreatePlatformProfiler());
-#endif
+  (*interpreter)->SetProfiler(tflite::profiling::MaybeCreatePlatformProfiler());
 
   for (int subgraph_index = 0; subgraph_index < subgraphs->size();
        ++subgraph_index) {
diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc
index a35799a..bd0f724 100644
--- a/tensorflow/lite/interpreter_test.cc
+++ b/tensorflow/lite/interpreter_test.cc
@@ -1484,9 +1484,9 @@
   void SetUp() override {
     // Simple model with two custom ops that add 2 float tensors each.
     interpreter_.reset(new Interpreter);
-    interpreter_->AddTensors(5);
+    interpreter_->AddTensors(7);
     interpreter_->SetInputs({0, 1});
-    interpreter_->SetOutputs({3, 4});
+    interpreter_->SetOutputs({3, 4, 6});
     TfLiteQuantizationParams quant;
     interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
                                                quant);
@@ -1498,6 +1498,10 @@
                                                quant);
     interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
                                                quant);
+    interpreter_->SetTensorParametersReadWrite(5, kTfLiteFloat32, "", {3},
+                                               quant, /*is_variable=*/true);
+    interpreter_->SetTensorParametersReadWrite(6, kTfLiteFloat32, "", {3},
+                                               quant);
     auto* add_reg = ops::builtin::Register_ADD();
     TfLiteAddParams* builtin_data0 =
         reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
@@ -1505,15 +1509,21 @@
         reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
     TfLiteAddParams* builtin_data2 =
         reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+    TfLiteAddParams* builtin_data3 =
+        reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
     builtin_data0->activation = kTfLiteActNone;
     builtin_data1->activation = kTfLiteActNone;
     builtin_data2->activation = kTfLiteActNone;
+    builtin_data3->activation = kTfLiteActNone;
     interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, builtin_data0,
                                         add_reg);
     interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, builtin_data1,
                                         add_reg);
     interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, builtin_data2,
                                         add_reg);
+    interpreter_->AddNodeWithParameters({0, 5}, {6}, nullptr, 0, builtin_data3,
+                                        add_reg);
+    interpreter_->SetVariables({5});
   }
 
   void AssignCustomAllocForTensor(int tensor_idx, int required_alignment) {
@@ -1526,18 +1536,22 @@
 
   void VerifyInvoke() {
     std::vector<float> input = {1.0f, 2.0f, 3.0f};
+    std::vector<float> variable = {0.0f, 1.0f, 2.0f};
     std::vector<float> expected_output = {2.0f, 4.0f, 6.0f};
-    TfLiteTensor* tensor = interpreter_->tensor(interpreter_->outputs()[0]);
 
     // typed_tensor<...> should work irrespective of custom alloc, since it
-    // accesses tensor.data.
+    // accesses output_tensor.data.
+    memcpy(interpreter_->typed_tensor<float>(interpreter_->variables()[0]),
+           variable.data(), 3 * sizeof(float));
     memcpy(interpreter_->typed_tensor<float>(0), input.data(),
            3 * sizeof(float));
     memcpy(interpreter_->typed_tensor<float>(1), input.data(),
            3 * sizeof(float));
     ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
+    TfLiteTensor* output_tensor =
+        interpreter_->tensor(interpreter_->outputs()[0]);
     for (int i = 0; i < 3; ++i) {
-      EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i;
+      EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i;
     }
   }
 
@@ -1649,6 +1663,36 @@
   VerifyInvoke();
 }
 
+// Ensure that custom allocs work for tensors on persistent arena as well.
+TEST_F(TestCustomAllocation, CustomAlloc_VariableTensor) {
+  // Set custom allocation for a variable tensor.
+  AssignCustomAllocForTensor(interpreter_->variables()[0],
+                             /*required_alignment=*/kDefaultTensorAlignment);
+
+  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+  VerifyInvoke();
+
+  AssignCustomAllocForTensor(interpreter_->variables()[0],
+                             /*required_alignment=*/kDefaultTensorAlignment);
+  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+
+  std::vector<float> input = {2.0f, 3.0f, 4.0f};
+  std::vector<float> variable = {1.0f, 2.0f, 3.0f};
+  std::vector<float> expected_output = {3.0f, 5.0f, 7.0f};
+  memcpy(interpreter_->typed_tensor<float>(interpreter_->variables()[0]),
+         variable.data(), 3 * sizeof(float));
+  memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float));
+  memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float));
+  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
+
+  // expected_output = input + variable
+  TfLiteTensor* output_tensor =
+      interpreter_->tensor(interpreter_->outputs()[2]);
+  for (int i = 0; i < 3; ++i) {
+    EXPECT_EQ(output_tensor->data.f[i], expected_output[i]) << i;
+  }
+}
+
 TEST_F(TestCustomAllocation, ResizeTensorsWithoutEnoughMemory) {
   // Set custom allocations for all input tensors.
   AssignCustomAllocForTensor(interpreter_->inputs()[0],
diff --git a/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
index bb62b44..4665f59 100644
--- a/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
+++ b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -326,58 +326,55 @@
     final int deviceIndex = deviceView.getCheckedItemPosition();
     final int numThreads = np.getValue();
 
-    backgroundHandler.post(() -> {
-      if (modelIndex == currentModel && deviceIndex == currentDevice
+    backgroundHandler.post(
+        () -> {
+          if (modelIndex == currentModel
+              && deviceIndex == currentDevice
               && numThreads == currentNumThreads) {
-        return;
-      }
-      currentModel = modelIndex;
-      currentDevice = deviceIndex;
-      currentNumThreads = numThreads;
+            return;
+          }
+          currentModel = modelIndex;
+          currentDevice = deviceIndex;
+          currentNumThreads = numThreads;
 
-      // Disable classifier while updating
-      if (classifier != null) {
-        classifier.close();
-        classifier = null;
-      }
+          // Disable classifier while updating
+          if (classifier != null) {
+            classifier.close();
+            classifier = null;
+          }
 
-      // Lookup names of parameters.
-      String model = modelStrings.get(modelIndex);
-      String device = deviceStrings.get(deviceIndex);
+          // Lookup names of parameters.
+          String model = modelStrings.get(modelIndex);
+          String device = deviceStrings.get(deviceIndex);
 
-      Log.i(TAG, "Changing model to " + model + " device " + device);
+          Log.i(TAG, "Changing model to " + model + " device " + device);
 
-      // Try to load model.
-      try {
-        if (model.equals(mobilenetV1Quant)) {
-          classifier = new ImageClassifierQuantizedMobileNet(getActivity());
-        } else if (model.equals(mobilenetV1Float)) {
-          classifier = new ImageClassifierFloatMobileNet(getActivity());
-        } else {
-          showToast("Failed to load model");
-        }
-      } catch (IOException e) {
-        Log.d(TAG, "Failed to load", e);
-        classifier = null;
-      }
+          // Try to load model.
+          try {
+            if (model.equals(mobilenetV1Quant)) {
+              classifier = new ImageClassifierQuantizedMobileNet(getActivity());
+            } else if (model.equals(mobilenetV1Float)) {
+              classifier = new ImageClassifierFloatMobileNet(getActivity());
+            } else {
+              showToast("Failed to load model");
+            }
+          } catch (IOException e) {
+            Log.d(TAG, "Failed to load", e);
+            classifier = null;
+          }
 
-      // Customize the interpreter to the type of device we want to use.
-      if (classifier == null) {
-        return;
-      }
-      classifier.setNumThreads(numThreads);
-      if (device.equals(cpu)) {
-      } else if (device.equals(gpu)) {
-        if (model.equals(mobilenetV1Quant)) {
-          showToast("gpu requires float model.");
-          classifier = null;
-        } else {
-          classifier.useGpu();
-        }
-      } else if (device.equals(nnApi)) {
-        classifier.useNNAPI();
-      }
-    });
+          // Customize the interpreter to the type of device we want to use.
+          if (classifier == null) {
+            return;
+          }
+          classifier.setNumThreads(numThreads);
+          if (device.equals(cpu)) {
+          } else if (device.equals(gpu)) {
+            classifier.useGpu();
+          } else if (device.equals(nnApi)) {
+            classifier.useNNAPI();
+          }
+        });
   }
 
   /** Connect the buttons to their event handler. */
diff --git a/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
index 2e483d8..21149c9 100644
--- a/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
+++ b/tensorflow/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -172,7 +172,10 @@
 
   public void useGpu() {
     if (gpuDelegate == null) {
-      gpuDelegate = new GpuDelegate();
+      GpuDelegate.Options options = new GpuDelegate.Options();
+      options.setQuantizedModelsAllowed(true);
+
+      gpuDelegate = new GpuDelegate(options);
       tfliteOptions.addDelegate(gpuDelegate);
       recreateInterpreter();
     }
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 59afc0c..ba0f569 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -152,7 +152,7 @@
      * <ul>
      *   <li>Startup time and resize time may increase.
      *   <li>Baseline memory consumption may increase.
-     *   <li>Compatibility with other delegates (e.g., GPU) has not been fully validated.
+     *   <li>May be ignored if another delegate (e.g., NNAPI) has been applied.
      *   <li>Quantized models will not see any benefit.
      * </ul>
      *
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 2d1844f..fc0857f 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -367,8 +367,14 @@
     }
     tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate(
         xnnpack_create(&options), xnnpack_delete);
-    if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) !=
-        kTfLiteOk) {
+    auto delegation_status =
+        interpreter->ModifyGraphWithDelegate(std::move(delegate));
+    // kTfLiteApplicationError occurs in cases where delegation fails but
+    // the runtime is invokable (e.g., another delegate has already been applied).
+    // We don't throw an Exception in that case.
+    // TODO(b/166483905): Add support for multiple delegates when model allows.
+    if (delegation_status != kTfLiteOk &&
+        delegation_status != kTfLiteApplicationError) {
       ThrowException(env, kIllegalArgumentException,
                      "Internal error: Failed to apply XNNPACK delegate: %s",
                      error_reporter->CachedErrorMessage());
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java
index 45d66e2..fc9038c 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java
@@ -57,6 +57,25 @@
   }
 
   @Test
+  public void testInterpreterWithNnApiAndXNNPack() throws Exception {
+    Interpreter.Options options = new Interpreter.Options();
+    options.setUseXNNPACK(true);
+
+    try (NnApiDelegate delegate = new NnApiDelegate();
+        Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
+      float[] oneD = {1.23f, 6.54f, 7.81f};
+      float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+      float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+      float[][][][] fourD = {threeD, threeD};
+      float[][][][] parsedOutputs = new float[2][8][8][3];
+      interpreter.run(fourD, parsedOutputs);
+      float[] outputOneD = parsedOutputs[0][0][0];
+      float[] expected = {3.69f, 19.62f, 23.43f};
+      assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+    }
+  }
+
+  @Test
   public void testInterpreterWithNnApiAllowFp16() throws Exception {
     Interpreter.Options options = new Interpreter.Options();
     NnApiDelegate.Options nnApiOptions = new NnApiDelegate.Options();
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 9a672df..45a1c97 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -2,6 +2,7 @@
 load("//tensorflow/lite/micro:build_def.bzl", "micro_copts")
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined")
 load("//tensorflow:tensorflow.bzl", "tf_opts_nortti_if_android")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = [
@@ -429,6 +430,7 @@
     hdrs = [
         "op_macros.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     deps = ["//tensorflow/lite/micro:debug_log"],
 )
@@ -441,6 +443,7 @@
     hdrs = [
         "kernel_util.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + micro_copts(),
     deps = [
         "//tensorflow/lite/c:common",
@@ -697,20 +700,22 @@
 
 cc_library(
     name = "custom_ops",
-    srcs = ["rfft2d.cc"],
+    srcs = [
+        "complex_support.cc",
+        "cumsum.cc",
+        "rfft2d.cc",
+    ],
     hdrs = ["custom_ops_register.h"],
     copts = tflite_copts(),
     deps = [
         ":kernel_util",
-        ":op_macros",
-        "//tensorflow/lite:context",
         "//tensorflow/lite/c:common",
-        "//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
-        "//tensorflow/lite/kernels/internal:kernel_utils",
+        "//tensorflow/lite/kernels/internal:optimized_base",
         "//tensorflow/lite/kernels/internal:tensor",
         "//tensorflow/lite/kernels/internal:types",
         "//third_party/fft2d:fft2d_headers",
         "@fft2d",
+        "@flatbuffers",
         "@ruy//ruy/profiler:instrumentation",
     ],
 )
@@ -2288,4 +2293,34 @@
     ],
 )
 
+cc_test(
+    name = "complex_support_test",
+    srcs = ["complex_support_test.cc"],
+    deps = [
+        ":custom_ops",
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/testing:util",
+        "@com_google_googletest//:gtest",
+        "@flatbuffers",
+    ],
+)
+
+cc_test(
+    name = "cumsum_test",
+    srcs = ["cumsum_test.cc"],
+    deps = [
+        ":custom_ops",
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/testing:util",
+        "@com_google_googletest//:gtest",
+        "@flatbuffers",
+    ],
+)
+
 tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})
diff --git a/tensorflow/lite/kernels/complex_support.cc b/tensorflow/lite/kernels/complex_support.cc
new file mode 100644
index 0000000..7f5886c
--- /dev/null
+++ b/tensorflow/lite/kernels/complex_support.cc
@@ -0,0 +1,146 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+// TODO(b/165735381): Promote this op to builtin-op when we can add new builtin
+// ops.
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace complex {
+
+static const int kInputTensor = 0;
+static const int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+  TF_LITE_ENSURE(context, input->type == kTfLiteComplex64 ||
+                              input->type == kTfLiteComplex128);
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  if (input->type == kTfLiteComplex64) {
+    TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+  } else {
+    TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat64);
+  }
+
+  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
+  return context->ResizeTensor(context, output, output_shape);
+}
+
+template <typename T, typename ExtractF>
+void ExtractData(const TfLiteTensor* input, ExtractF extract_func,
+                 TfLiteTensor* output) {
+  const std::complex<T>* input_data = GetTensorData<std::complex<T>>(input);
+  T* output_data = GetTensorData<T>(output);
+  const int input_size = NumElements(input);
+  for (int i = 0; i < input_size; ++i) {
+    *output_data++ = extract_func(*input_data++);
+  }
+}
+
+TfLiteStatus EvalReal(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input->type) {
+    case kTfLiteComplex64: {
+      ExtractData<float>(
+          input,
+          static_cast<float (*)(const std::complex<float>&)>(std::real<float>),
+          output);
+      break;
+    }
+    case kTfLiteComplex128: {
+      ExtractData<double>(input,
+                          static_cast<double (*)(const std::complex<double>&)>(
+                              std::real<double>),
+                          output);
+      break;
+    }
+    default: {
+      TF_LITE_KERNEL_LOG(context,
+                         "Unsupported input type, Real op only supports "
+                         "complex input, but got: %s",
+                         TfLiteTypeGetName(input->type));
+      return kTfLiteError;
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input->type) {
+    case kTfLiteComplex64: {
+      ExtractData<float>(
+          input,
+          static_cast<float (*)(const std::complex<float>&)>(std::imag<float>),
+          output);
+      break;
+    }
+    case kTfLiteComplex128: {
+      ExtractData<double>(input,
+                          static_cast<double (*)(const std::complex<double>&)>(
+                              std::imag<double>),
+                          output);
+      break;
+    }
+    default: {
+      TF_LITE_KERNEL_LOG(context,
+                         "Unsupported input type, Imag op only supports "
+                         "complex input, but got: %s",
+                         TfLiteTypeGetName(input->type));
+      return kTfLiteError;
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace complex
+
+TfLiteRegistration* Register_REAL() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 complex::Prepare, complex::EvalReal};
+  return &r;
+}
+
+TfLiteRegistration* Register_IMAG() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 complex::Prepare, complex::EvalImag};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
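A minimal usage sketch for these registrations (assuming
tflite::MutableOpResolver; the op names "Real" and "Imag" match the ones the
tests below use):

  tflite::MutableOpResolver resolver;
  resolver.AddCustom("Real", tflite::ops::custom::Register_REAL());
  resolver.AddCustom("Imag", tflite::ops::custom::Register_IMAG());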
diff --git a/tensorflow/lite/kernels/complex_support_test.cc b/tensorflow/lite/kernels/complex_support_test.cc
new file mode 100644
index 0000000..cb60345
--- /dev/null
+++ b/tensorflow/lite/kernels/complex_support_test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <complex>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/custom_ops_register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/testing/util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_REAL();
+TfLiteRegistration* Register_IMAG();
+
+namespace {
+
+template <typename T>
+class RealOpModel : public SingleOpModel {
+ public:
+  RealOpModel(const TensorData& input, const TensorData& output) {
+    input_ = AddInput(input);
+
+    output_ = AddOutput(output);
+
+    const std::vector<uint8_t> custom_option;
+    SetCustomOp("Real", custom_option, Register_REAL);
+
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  int input() { return input_; }
+
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+
+ private:
+  int input_;
+  int output_;
+};
+
+TEST(RealOpTest, SimpleFloatTest) {
+  RealOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
+                       {TensorType_FLOAT32, {}});
+
+  m.PopulateTensor<std::complex<float>>(m.input(), {{75, 0},
+                                                    {-6, -1},
+                                                    {9, 0},
+                                                    {-10, 5},
+                                                    {-3, 2},
+                                                    {-6, 11},
+                                                    {0, 0},
+                                                    {22.1, 33.3}});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
+                                 {75, -6, 9, -10, -3, -6, 0, 22.1f})));
+}
+
+TEST(RealOpTest, SimpleDoubleTest) {
+  RealOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
+                        {TensorType_FLOAT64, {}});
+
+  m.PopulateTensor<std::complex<double>>(m.input(), {{75, 0},
+                                                     {-6, -1},
+                                                     {9, 0},
+                                                     {-10, 5},
+                                                     {-3, 2},
+                                                     {-6, 11},
+                                                     {0, 0},
+                                                     {22.1, 33.3}});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
+                                 {75, -6, 9, -10, -3, -6, 0, 22.1f})));
+}
+
+template <typename T>
+class ImagOpModel : public SingleOpModel {
+ public:
+  ImagOpModel(const TensorData& input, const TensorData& output) {
+    input_ = AddInput(input);
+
+    output_ = AddOutput(output);
+
+    const std::vector<uint8_t> custom_option;
+    SetCustomOp("Imag", custom_option, Register_IMAG);
+
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  int input() { return input_; }
+
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+
+ private:
+  int input_;
+  int output_;
+};
+
+TEST(ImagOpTest, SimpleFloatTest) {
+  ImagOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
+                       {TensorType_FLOAT32, {}});
+
+  m.PopulateTensor<std::complex<float>>(m.input(), {{75, 7},
+                                                    {-6, -1},
+                                                    {9, 3.5},
+                                                    {-10, 5},
+                                                    {-3, 2},
+                                                    {-6, 11},
+                                                    {0, 0},
+                                                    {22.1, 33.3}});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
+                                 {7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
+}
+
+TEST(ImagOpTest, SimpleDoubleTest) {
+  ImagOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
+                        {TensorType_FLOAT64, {}});
+
+  m.PopulateTensor<std::complex<double>>(m.input(), {{75, 7},
+                                                     {-6, -1},
+                                                     {9, 3.5},
+                                                     {-10, 5},
+                                                     {-3, 2},
+                                                     {-6, 11},
+                                                     {0, 0},
+                                                     {22.1, 33.3}});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
+                                 {7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
+}
+
+}  // namespace
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/kernels/cumsum.cc b/tensorflow/lite/kernels/cumsum.cc
new file mode 100644
index 0000000..173de09
--- /dev/null
+++ b/tensorflow/lite/kernels/cumsum.cc
@@ -0,0 +1,125 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+// TODO(b/161933288): Promote this op to builtin-op when we can add new builtin
+// ops.
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace cumsum {
+
+typedef struct {
+  bool exclusive;
+  bool reverse;
+} TfLiteCumsumParams;
+
+static const int kInputTensor = 0;
+static const int kAxisTensor = 1;
+static const int kOutputTensor = 0;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  auto* data = new TfLiteCumsumParams;
+  const uint8_t* buffer_data = reinterpret_cast<const uint8_t*>(buffer);
+
+  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_data, length).AsMap();
+  data->exclusive = m["exclusive"].AsBool();
+  data->reverse = m["reverse"].AsBool();
+  return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<TfLiteCumsumParams*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
+
+  TF_LITE_ENSURE(context,
+                 input->type == kTfLiteInt32 || input->type == kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, axis->type, kTfLiteInt32);
+
+  TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
+
+  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
+  return context->ResizeTensor(context, output, output_shape);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  auto* params = reinterpret_cast<TfLiteCumsumParams*>(node->user_data);
+
+  int axis = *GetTensorData<int>(axis_tensor);
+  if (axis < 0) axis += NumDimensions(input);
+
+  if (axis < 0 || axis >= NumDimensions(input)) {
+    TF_LITE_KERNEL_LOG(context, "Invalid axis: ", axis);
+    return kTfLiteError;
+  }
+
+  switch (input->type) {
+    case kTfLiteInt32: {
+      optimized_ops::CumSum(GetTensorData<int>(input), GetTensorShape(input),
+                            axis, params->exclusive, params->reverse,
+                            GetTensorData<int>(output));
+      break;
+    }
+    case kTfLiteFloat32: {
+      optimized_ops::CumSum(GetTensorData<float>(input), GetTensorShape(input),
+                            axis, params->exclusive, params->reverse,
+                            GetTensorData<float>(output));
+      break;
+    }
+    default: {
+      TF_LITE_KERNEL_LOG(
+          context,
+          "Unsupported input type, cumsum only supports int32 & float32.");
+      return kTfLiteError;
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace cumsum
+
+TfLiteRegistration* Register_CUMSUM() {
+  static TfLiteRegistration r = {cumsum::Init, cumsum::Free, cumsum::Prepare,
+                                 cumsum::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
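For reference, the exclusive/reverse semantics this kernel implements, worked
by hand on a 1-D input (values are illustrative, not taken from the patch):

  // input {1, 2, 3, 4}, axis = 0
  //   exclusive=false, reverse=false -> {1, 3, 6, 10}
  //   exclusive=true,  reverse=false -> {0, 1, 3, 6}
  //   exclusive=false, reverse=true  -> {10, 9, 7, 4}
  //   exclusive=true,  reverse=true  -> {9, 7, 4, 0}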
diff --git a/tensorflow/lite/kernels/cumsum_test.cc b/tensorflow/lite/kernels/cumsum_test.cc
new file mode 100644
index 0000000..092defd
--- /dev/null
+++ b/tensorflow/lite/kernels/cumsum_test.cc
@@ -0,0 +1,148 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/custom_ops_register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/testing/util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_CUMSUM();
+
+namespace {
+
+template <typename T>
+class CumsumOpModel : public SingleOpModel {
+ public:
+  CumsumOpModel(const TensorData& input, const TensorData& output,
+                bool exclusive, bool reverse) {
+    input_ = AddInput(input);
+    axis_ = AddInput({TensorType_INT32, {1}});
+
+    output_ = AddOutput(output);
+
+    flexbuffers::Builder fbb;
+    fbb.Map([&]() {
+      fbb.Bool("exclusive", exclusive);
+      fbb.Bool("reverse", reverse);
+    });
+    fbb.Finish();
+    SetCustomOp("Cumsum", fbb.GetBuffer(), Register_CUMSUM);
+
+    BuildInterpreter({GetShape(input_), GetShape(axis_)});
+  }
+
+  int input() { return input_; }
+  int axis() { return axis_; }
+
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+
+ private:
+  int input_;
+  int axis_;
+  int output_;
+};
+
+TEST(CumsumOpTest, SimpleIntTest) {
+  CumsumOpModel<int32_t> m({TensorType_INT32, {2, 4}}, {TensorType_INT32, {}},
+                           false, false);
+
+  m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6, 7, 8});
+  m.PopulateTensor<int>(m.axis(), {1});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              testing::ElementsAreArray({1, 3, 6, 10, 5, 11, 18, 26}));
+}
+
+TEST(CumsumOpTest, SimpleIntAxis0Test) {
+  CumsumOpModel<int32_t> m({TensorType_INT32, {2, 4}}, {TensorType_INT32, {}},
+                           false, false);
+
+  m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6, 7, 8});
+  m.PopulateTensor<int>(m.axis(), {0});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              testing::ElementsAreArray({1, 2, 3, 4, 6, 8, 10, 12}));
+}
+
+TEST(CumsumOpTest, Simple1DIntTest) {
+  CumsumOpModel<int32_t> m({TensorType_INT32, {8}}, {TensorType_INT32, {}},
+                           false, false);
+
+  m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6, 7, 8});
+  m.PopulateTensor<int>(m.axis(), {0});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              testing::ElementsAreArray({1, 3, 6, 10, 15, 21, 28, 36}));
+}
+
+TEST(CumsumOpTest, SimpleIntReverseTest) {
+  CumsumOpModel<int32_t> m({TensorType_INT32, {2, 4}}, {TensorType_INT32, {}},
+                           false, true);
+
+  m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6, 7, 8});
+  m.PopulateTensor<int>(m.axis(), {1});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              testing::ElementsAreArray({10, 9, 7, 4, 26, 21, 15, 8}));
+}
+
+TEST(CumsumOpTest, SimpleIntExclusiveTest) {
+  CumsumOpModel<int32_t> m({TensorType_INT32, {2, 4}}, {TensorType_INT32, {}},
+                           true, false);
+
+  m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6, 7, 8});
+  m.PopulateTensor<int>(m.axis(), {1});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              testing::ElementsAreArray({0, 1, 3, 6, 0, 5, 11, 18}));
+}
+
+TEST(CumsumOpTest, SimpleFloatTest) {
+  CumsumOpModel<float> m({TensorType_FLOAT32, {2, 4}}, {TensorType_FLOAT32, {}},
+                         false, false);
+
+  m.PopulateTensor<float>(m.input(), {1, 2, 3, 4, 5, 6, 7, 8});
+  m.PopulateTensor<int>(m.axis(), {1});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(
+                                 ArrayFloatNear({1, 3, 6, 10, 5, 11, 18, 26})));
+}
+
+}  // namespace
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/custom_ops_register.h b/tensorflow/lite/kernels/custom_ops_register.h
index 3abc893..659091f 100644
--- a/tensorflow/lite/kernels/custom_ops_register.h
+++ b/tensorflow/lite/kernels/custom_ops_register.h
@@ -21,12 +21,16 @@
 namespace ops {
 namespace custom {
 
+TfLiteRegistration* Register_CUMSUM();
 TfLiteRegistration* Register_RFFT2D();
 TfLiteRegistration* Register_HASHTABLE();
 TfLiteRegistration* Register_HASHTABLE_FIND();
 TfLiteRegistration* Register_HASHTABLE_IMPORT();
 TfLiteRegistration* Register_HASHTABLE_SIZE();
-}
+TfLiteRegistration* Register_REAL();
+TfLiteRegistration* Register_IMAG();
+
+}  // namespace custom
 }  // namespace ops
 }  // namespace tflite
 
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 2588d4f..b109e8e 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -3,6 +3,7 @@
 load("//tensorflow/lite:build_def.bzl", "tflite_copts")
 load("//tensorflow/lite/micro:build_def.bzl", "micro_copts")
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = [
@@ -47,6 +48,7 @@
 cc_library(
     name = "compatibility",
     hdrs = ["compatibility.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     deps = [
         "//tensorflow/lite/kernels:op_macros",
@@ -56,6 +58,7 @@
 cc_library(
     name = "types",
     hdrs = ["types.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     deps = [
         ":compatibility",
@@ -372,6 +375,7 @@
         "max.h",
         "min.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
 )
 
@@ -379,6 +383,7 @@
     name = "quantization_util",
     srcs = ["quantization_util.cc"],
     hdrs = ["quantization_util.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts() + micro_copts(),
     deps = [
         ":compatibility",
@@ -459,6 +464,7 @@
         "reference/integer_ops/fully_connected.h",
         "reference/integer_ops/l2normalization.h",
         "reference/integer_ops/logistic.h",
+        "reference/integer_ops/mean.h",
         "reference/integer_ops/mul.h",
         "reference/integer_ops/pooling.h",
         "reference/integer_ops/tanh.h",
@@ -487,7 +493,6 @@
         "//conditions:default": [
             "reference/integer_ops/dequantize.h",
             "reference/integer_ops/log_softmax.h",
-            "reference/integer_ops/mean.h",
             "reference/integer_ops/transpose_conv.h",
             "reference/reference_ops.h",
             "reference/string_comparisons.h",
@@ -608,6 +613,7 @@
             "tensor.h",
         ],
     }),
+    compatible_with = get_compatible_with_portable(),
     copts = tflite_copts(),
     deps = [
         ":types",
diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h
index 66a2d97..92bb01a 100644
--- a/tensorflow/lite/kernels/internal/common.h
+++ b/tensorflow/lite/kernels/internal/common.h
@@ -263,6 +263,30 @@
       std::min(std::max(TfLiteRound(func(max) * 32768.0), -32768.0), 32767.0);
 }
 
+// Generates an int16 LUT for a function, e.g., the exp(x) and 1/(1+x) tables
+// used in softmax.
+inline void gen_lut(const std::function<float(float)>& func, float min,
+                    float max, int16_t* table, const int num) {
+  // The table is filled with num values; the last value is used only for
+  // slope calculation during interpolation.
+  float step = (max - min) / (num - 1);
+  float half_step = step / 2.0f;
+  for (int i = 0; i < num - 1; i++) {
+    float sample_val = TfLiteRound(func(min + i * step) * 32768.0f);
+    float midpoint_interp_val =
+        TfLiteRound((func(min + (i + 1) * step) * 32768.0f +
+                     TfLiteRound(func(min + i * step) * 32768.0f)) /
+                    2.0f);
+    float midpoint_val =
+        TfLiteRound(func(min + i * step + half_step) * 32768.0f);
+    float midpoint_err = midpoint_interp_val - midpoint_val;
+    float bias = TfLiteRound(midpoint_err / 2.0f);
+    table[i] = std::min(std::max(sample_val - bias, -32768.0f), 32767.0f);
+  }
+  table[num - 1] = std::min(
+      std::max(TfLiteRound(func(max) * 32768.0f), -32768.0f), 32767.0f);
+}
+
 // int16_t func table lookup, e.g., lookup exp() and 1/(1+x) used in softmax
 inline int16_t generic_int16_table_lookup(int16_t value, const int16_t* lut) {
   // 512 base value, lut[513] only for calculate slope
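A hypothetical call sketch for gen_lut(), building the exp() table documented
in types.h later in this patch (the 513-entry size is an assumption inferred
from the 512-interval lookup, not something this change pins down):

  int16_t exp_lut[513];  // 512 base values plus a final entry for the slope
  gen_lut([](float x) { return std::exp(x); }, /*min=*/-10.0f, /*max=*/0.0f,
          exp_lut, /*num=*/513);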
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index c505ee8..5290a08 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -8132,6 +8132,46 @@
                           MinimumElementwise, MinimumScalarBroadcast);
 }
 
+template <typename T>
+void CumsumImpl(const T* input_data, const RuntimeShape& shape, int axis,
+                bool exclusive, bool reverse, T* output_data) {
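+  // Collapse the shape to 3-D: (product of dims before axis, the axis dim,
+  // product of dims after axis), so the scan always runs along dimension 1.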
+  Eigen::array<Eigen::DenseIndex, 3> dims = {1, 1, 1};
+
+  for (int i = 0; i < axis; ++i) {
+    dims[0] *= shape.Dims(i);
+  }
+  dims[1] = shape.Dims(axis);
+  for (int i = axis + 1; i < shape.DimensionsCount(); ++i) {
+    dims[2] *= shape.Dims(i);
+  }
+
+  typedef Eigen::TensorMap<
+      Eigen::Tensor<const T, 3, Eigen::RowMajor, Eigen::DenseIndex>,
+      Eigen::Aligned>
+      ConstTensor;
+  typedef Eigen::TensorMap<
+      Eigen::Tensor<T, 3, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned>
+      Tensor;
+  ConstTensor input(input_data, dims);
+  Tensor output(output_data, dims);
+
+  if (reverse) {
+    Eigen::array<bool, 3> reverse_idx = {false, true, false};
+    output =
+        input.reverse(reverse_idx).cumsum(1, exclusive).reverse(reverse_idx);
+  } else {
+    output = input.cumsum(1, exclusive);
+  }
+}
+
+template <typename T>
+void CumSum(const T* input_data, const RuntimeShape& shape, int axis,
+            bool exclusive, bool reverse, T* output_data) {
+  const int dim = shape.DimensionsCount();
+  TFLITE_DCHECK_GE(dim, 1);
+  CumsumImpl<T>(input_data, shape, axis, exclusive, reverse, output_data);
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
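A minimal direct-call sketch for the new helper (illustrative; sizing the
output buffer is the caller's responsibility):

  const int32_t input[] = {1, 2, 3, 4, 5, 6};
  int32_t output[6];
  // Treating the data as shape {2, 3} and scanning along axis 1 yields
  // {1, 3, 6, 4, 9, 15}.
  optimized_ops::CumSum(input, RuntimeShape({2, 3}), /*axis=*/1,
                        /*exclusive=*/false, /*reverse=*/false, output);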
 
diff --git a/tensorflow/lite/kernels/internal/reference/conv.h b/tensorflow/lite/kernels/internal/reference/conv.h
index d4bf46a..b912ac1 100644
--- a/tensorflow/lite/kernels/internal/reference/conv.h
+++ b/tensorflow/lite/kernels/internal/reference/conv.h
@@ -59,28 +59,31 @@
   const int output_width = output_shape.Dims(2);
   for (int batch = 0; batch < batches; ++batch) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
+      const int in_y_origin = (out_y * stride_height) - pad_height;
       for (int out_x = 0; out_x < output_width; ++out_x) {
+        const int in_x_origin = (out_x * stride_width) - pad_width;
         for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
-          const int in_x_origin = (out_x * stride_width) - pad_width;
-          const int in_y_origin = (out_y * stride_height) - pad_height;
           float total = 0.f;
           for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            const int in_y = in_y_origin + dilation_height_factor * filter_y;
             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int in_x = in_x_origin + dilation_width_factor * filter_x;
+
+              // Zero padding by omitting the areas outside the image.
+              const bool is_point_inside_image =
+                  (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+                  (in_y < input_height);
+
+              if (!is_point_inside_image) {
+                continue;
+              }
+
               for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
-                const int in_x = in_x_origin + dilation_width_factor * filter_x;
-                const int in_y =
-                    in_y_origin + dilation_height_factor * filter_y;
-                // If the location is outside the bounds of the input image,
-                // use zero as a default value.
-                if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                    (in_y < input_height)) {
-                  float input_value = input_data[Offset(
-                      input_shape, batch, in_y, in_x, in_channel)];
-                  float filter_value =
-                      filter_data[Offset(filter_shape, out_channel, filter_y,
-                                         filter_x, in_channel)];
-                  total += (input_value * filter_value);
-                }
+                float input_value = input_data[Offset(input_shape, batch, in_y,
+                                                      in_x, in_channel)];
+                float filter_value = filter_data[Offset(
+                    filter_shape, out_channel, filter_y, filter_x, in_channel)];
+                total += (input_value * filter_value);
               }
             }
           }
@@ -139,29 +142,32 @@
   const int output_width = output_shape.Dims(2);
   for (int batch = 0; batch < batches; ++batch) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
+      const int in_y_origin = (out_y * stride_height) - pad_height;
       for (int out_x = 0; out_x < output_width; ++out_x) {
+        const int in_x_origin = (out_x * stride_width) - pad_width;
         for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
-          const int in_x_origin = (out_x * stride_width) - pad_width;
-          const int in_y_origin = (out_y * stride_height) - pad_height;
           int32_t acc = 0;
           for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            const int in_y = in_y_origin + dilation_height_factor * filter_y;
             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int in_x = in_x_origin + dilation_width_factor * filter_x;
+
+              // Zero padding by omitting the areas outside the image.
+              const bool is_point_inside_image =
+                  (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+                  (in_y < input_height);
+
+              if (!is_point_inside_image) {
+                continue;
+              }
+
               for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
-                const int in_x = in_x_origin + dilation_width_factor * filter_x;
-                const int in_y =
-                    in_y_origin + dilation_height_factor * filter_y;
-                // If the location is outside the bounds of the input image,
-                // use zero as a default value.
-                if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                    (in_y < input_height)) {
-                  int32_t input_val = input_data[Offset(
-                      input_shape, batch, in_y, in_x, in_channel)];
-                  int32_t filter_val =
-                      filter_data[Offset(filter_shape, out_channel, filter_y,
-                                         filter_x, in_channel)];
-                  acc +=
-                      (filter_val + filter_offset) * (input_val + input_offset);
-                }
+                int32_t input_val = input_data[Offset(input_shape, batch, in_y,
+                                                      in_x, in_channel)];
+                int32_t filter_val = filter_data[Offset(
+                    filter_shape, out_channel, filter_y, filter_x, in_channel)];
+                acc +=
+                    (filter_val + filter_offset) * (input_val + input_offset);
               }
             }
           }
@@ -258,5 +264,4 @@
 }  // namespace reference_ops
 }  // namespace tflite
 
-
 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV_H_
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
index f4bcb2b..3e9cd0c 100644
--- a/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/conv.h
@@ -63,45 +63,47 @@
   const int output_width = output_shape.Dims(2);
   for (int batch = 0; batch < batches; ++batch) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
+      const int in_y_origin = (out_y * stride_height) - pad_height;
       for (int out_x = 0; out_x < output_width; ++out_x) {
+        const int in_x_origin = (out_x * stride_width) - pad_width;
         for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
-          const int in_x_origin = (out_x * stride_width) - pad_width;
-          const int in_y_origin = (out_y * stride_height) - pad_height;
           int32_t acc = 0;
           for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            const int in_y = in_y_origin + dilation_height_factor * filter_y;
             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int in_x = in_x_origin + dilation_width_factor * filter_x;
+
+              // Zero padding by omitting the areas outside the image.
+              const bool is_point_inside_image =
+                  (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+                  (in_y < input_height);
+
+              if (!is_point_inside_image) {
+                continue;
+              }
+
               for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
-                const int in_x = in_x_origin + dilation_width_factor * filter_x;
-                const int in_y =
-                    in_y_origin + dilation_height_factor * filter_y;
-                // Zero padding by omitting the areas outside the image.
-                const bool is_point_inside_image =
-                    (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                    (in_y < input_height);
-                if (is_point_inside_image) {
-                  int32_t input_val = input_data[Offset(
-                      input_shape, batch, in_y, in_x, in_channel)];
-                  int32_t filter_val =
-                      filter_data[Offset(filter_shape, out_channel, filter_y,
-                                         filter_x, in_channel)];
-                  // Accumulate with 32 bits accumulator.
-                  // In the nudging process during model quantization, we force
-                  // real value of 0.0 be represented by a quantized value. This
-                  // guarantees that the input_offset is a int8_t, even though
-                  // it is represented using int32_t. int32_t += int8_t *
-                  // (int8_t - int8_t) so the highest value we can get from each
-                  // accumulation is [-127, 127] * ([-128, 127] -
-                  // [-128, 127]), which is [-32512, 32512]. log2(32512)
-                  // = 14.98, which means we can accumulate at least 2^16
-                  // multiplications without overflow. The accumulator is
-                  // applied to a filter so the accumulation logic will hold as
-                  // long as the filter size (filter_y * filter_x * in_channel)
-                  // does not exceed 2^16, which is the case in all the models
-                  // we have seen so far.
-                  // TODO(jianlijianli): Add a check to make sure the
-                  // accumulator depth is smaller than 2^16.
-                  acc += filter_val * (input_val + input_offset);
-                }
+                int32_t input_val = input_data[Offset(input_shape, batch, in_y,
+                                                      in_x, in_channel)];
+                int32_t filter_val = filter_data[Offset(
+                    filter_shape, out_channel, filter_y, filter_x, in_channel)];
+                // Accumulate with 32 bits accumulator.
+                // In the nudging process during model quantization, we force
+                // real value of 0.0 be represented by a quantized value. This
+                // guarantees that the input_offset is a int8_t, even though
+                // it is represented using int32_t. int32_t += int8_t *
+                // (int8_t - int8_t) so the highest value we can get from each
+                // accumulation is [-127, 127] * ([-128, 127] -
+                // [-128, 127]), which is [-32512, 32512]. log2(32512)
+                // = 14.98, which means we can accumulate at least 2^16
+                // multiplications without overflow. The accumulator is
+                // applied to a filter so the accumulation logic will hold as
+                // long as the filter size (filter_y * filter_x * in_channel)
+                // does not exceed 2^16, which is the case in all the models
+                // we have seen so far.
+                // TODO(jianlijianli): Add a check to make sure the
+                // accumulator depth is smaller than 2^16.
+                acc += filter_val * (input_val + input_offset);
               }
             }
           }
@@ -164,35 +166,37 @@
   const int output_width = output_shape.Dims(2);
   for (int batch = 0; batch < batches; ++batch) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
+      const int in_y_origin = (out_y * stride_height) - pad_height;
       for (int out_x = 0; out_x < output_width; ++out_x) {
+        const int in_x_origin = (out_x * stride_width) - pad_width;
         for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
-          const int in_x_origin = (out_x * stride_width) - pad_width;
-          const int in_y_origin = (out_y * stride_height) - pad_height;
           std::int64_t acc = 0;
           for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            const int in_y = in_y_origin + dilation_height_factor * filter_y;
             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int in_x = in_x_origin + dilation_width_factor * filter_x;
+
+              // Zero padding by omitting the areas outside the image.
+              const bool is_point_inside_image =
+                  (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+                  (in_y < input_height);
+
+              if (!is_point_inside_image) {
+                continue;
+              }
+
               for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
-                const int in_x = in_x_origin + dilation_width_factor * filter_x;
-                const int in_y =
-                    in_y_origin + dilation_height_factor * filter_y;
-                // Zero padding by omitting the areas outside the image.
-                const bool is_point_inside_image =
-                    (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                    (in_y < input_height);
-                if (is_point_inside_image) {
-                  int32_t input_val = input_data[Offset(
-                      input_shape, batch, in_y, in_x, in_channel)];
-                  int32_t filter_val =
-                      filter_data[Offset(filter_shape, out_channel, filter_y,
-                                         filter_x, in_channel)];
-                  // Accumulate with 64 bits accumulator.
-                  // int64_t += int8_t * int16_t so the highest value we can
-                  // get from each accumulation is [-127, 127] * ([-32768,
-                  // 32767] -
-                  // [-32768, 32767]), which is [-8322945, 8322945].
-                  // log2(8322945) = 22.99.
-                  acc += filter_val * input_val;
-                }
+                int32_t input_val = input_data[Offset(input_shape, batch, in_y,
+                                                      in_x, in_channel)];
+                int32_t filter_val = filter_data[Offset(
+                    filter_shape, out_channel, filter_y, filter_x, in_channel)];
+                // Accumulate with 64 bits accumulator.
+                // int64_t += int8_t * int16_t so the highest value we can
+                // get from each accumulation is [-127, 127] * ([-32768,
+                // 32767] -
+                // [-32768, 32767]), which is [-8322945, 8322945].
+                // log2(8322945) = 22.99.
+                acc += filter_val * input_val;
               }
             }
           }
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h
index 1e29f8c..bd48427 100644
--- a/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h
@@ -23,9 +23,9 @@
 template <typename integer_type>
 inline void Mean(const tflite::MeanParams& op_params, int32_t multiplier,
                  int32_t shift, const RuntimeShape& unextended_input_shape,
-                 const integer_type* input_data, int32 input_zero_point,
+                 const integer_type* input_data, int32_t input_zero_point,
                  const RuntimeShape& unextended_output_shape,
-                 integer_type* output_data, int32 output_zero_point) {
+                 integer_type* output_data, int32_t output_zero_point) {
   // Current implementation only supports dimension equals 4 and simultaneous
   // reduction over width and height.
   TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
@@ -53,7 +53,7 @@
 
   for (int out_b = 0; out_b < output_batch; ++out_b) {
     for (int out_d = 0; out_d < output_depth; ++out_d) {
-      int32 acc = 0;
+      int32_t acc = 0;
       for (int in_h = 0; in_h < input_height; ++in_h) {
         for (int in_w = 0; in_w < input_width; ++in_w) {
           acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)] -
diff --git a/tensorflow/lite/kernels/internal/reference/reduce.h b/tensorflow/lite/kernels/internal/reference/reduce.h
index 597d015..7953b43 100644
--- a/tensorflow/lite/kernels/internal/reference/reduce.h
+++ b/tensorflow/lite/kernels/internal/reference/reduce.h
@@ -186,11 +186,11 @@
   }
 
   // Calculate mean by dividing output_data by num of aggregated element.
-  U num_elements_in_axis = 1;
+  size_t num_elements_in_axis = 1;
   for (int idx = 0; idx < num_resolved_axis; ++idx) {
     size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
     // Overflow prevention.
-    if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
+    if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis)) {
       return false;
     }
     num_elements_in_axis *= current;
@@ -359,11 +359,11 @@
   }
 
   // Calculate mean by dividing output_data by num of aggregated element.
-  U num_elements_in_axis = 1;
+  size_t num_elements_in_axis = 1;
   for (int idx = 0; idx < num_resolved_axis; ++idx) {
     size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
     // Overflow prevention.
-    if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
+    if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis)) {
       return false;
     }
     num_elements_in_axis *= current;
diff --git a/tensorflow/lite/kernels/internal/reference/softmax.h b/tensorflow/lite/kernels/internal/reference/softmax.h
index b035b43..ee5bd1e 100644
--- a/tensorflow/lite/kernels/internal/reference/softmax.h
+++ b/tensorflow/lite/kernels/internal/reference/softmax.h
@@ -49,15 +49,15 @@
     // Compute sum.
     float sum = 0.f;
     for (int c = 0; c < depth; ++c) {
-      sum += std::exp((input_data[i * depth + c] - max) *
-                      static_cast<float>(params.beta));
+      const float exp_c = std::exp((input_data[i * depth + c] - max) *
+                                   static_cast<float>(params.beta));
+      output_data[i * depth + c] = exp_c;
+      sum += exp_c;
     }
 
     // Compute result.
     for (int c = 0; c < depth; ++c) {
-      output_data[i * depth + c] = std::exp((input_data[i * depth + c] - max) *
-                                            static_cast<float>(params.beta)) /
-                                   sum;
+      output_data[i * depth + c] = output_data[i * depth + c] / sum;
     }
   }
 }
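Editorial note on the refactor above: both forms compute
softmax(x[c]) = exp(beta * (x[c] - max)) / sum_k exp(beta * (x[k] - max));
the new version caches each exponential in output_data during the summation
pass and normalizes in place, so std::exp runs once per element instead of
twice.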
diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h
index 9db742d..0164f82 100644
--- a/tensorflow/lite/kernels/internal/types.h
+++ b/tensorflow/lite/kernels/internal/types.h
@@ -1044,7 +1044,9 @@
   int32_t zero_point;
   float scale;
   float* table;
+  // int16 LUT for exp(x), where x is uniformly distributed in [-10.0, 0.0]
   int16_t* exp_lut;
+  // int16 LUT for 1 / (1 + x), where x is uniformly distributed in [0.0, 1.0]
   int16_t* one_over_one_plus_x_lut;
   uint8_t* uint8_table1;
   uint8_t* uint8_table2;
diff --git a/tensorflow/lite/kernels/range.cc b/tensorflow/lite/kernels/range.cc
index fe67d05..71ee420 100644
--- a/tensorflow/lite/kernels/range.cc
+++ b/tensorflow/lite/kernels/range.cc
@@ -41,8 +41,8 @@
 TfLiteStatus GetSize(TfLiteContext* context, T start, T limit, T delta,
                      int* size) {
   TF_LITE_ENSURE(context, !std::equal_to<T>()(delta, 0));
-  TF_LITE_ENSURE(context,
-                 (start > limit && delta < 0) || (start < limit && delta > 0));
+  TF_LITE_ENSURE(
+      context, (start >= limit && delta < 0) || (start <= limit && delta > 0));
   *size =
       (std::is_integral<T>::value
            ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
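The relaxed check makes an empty range representable: for example, start = 0,
limit = 0, delta = 1 now passes validation, and the integral branch yields
size = (|0 - 0| + |1| - 1) / |1| = 0, which the new EmptyOutput test below
exercises.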
diff --git a/tensorflow/lite/kernels/range_test.cc b/tensorflow/lite/kernels/range_test.cc
index 52f7231..bb11d15 100644
--- a/tensorflow/lite/kernels/range_test.cc
+++ b/tensorflow/lite/kernels/range_test.cc
@@ -112,5 +112,15 @@
   EXPECT_THAT(model.GetOutput(), ElementsAre(10, 7, 4));
 }
 
+TEST(RangeOpModel, EmptyOutput) {
+  RangeOpModel<int32_t> model(TensorType_INT32);
+  model.PopulateTensor<int32_t>(model.start(), {0});
+  model.PopulateTensor<int32_t>(model.limit(), {0});
+  model.PopulateTensor<int32_t>(model.delta(), {1});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(0));
+  EXPECT_THAT(model.GetOutput(), ElementsAre());
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc
index c3a4025..cfa5557 100644
--- a/tensorflow/lite/kernels/test_util.cc
+++ b/tensorflow/lite/kernels/test_util.cc
@@ -145,10 +145,22 @@
       CustomOptionsFormat_FLEXBUFFERS));
 }
 
+void SingleOpModel::AllocateAndDelegate(bool apply_delegate) {
+  CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
+      << "Cannot allocate tensors";
+  interpreter_->ResetVariableTensors();
+
+  // In some rare cases a test may need to postpone modifying the graph with
+  // a delegate, e.g. if tensors are not fully specified. In such cases the
+  // test has to explicitly call ApplyDelegate() when necessary.
+  if (apply_delegate) ApplyDelegate();
+}
+
 void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
                                      int num_threads,
                                      bool allow_fp32_relax_to_fp16,
-                                     bool apply_delegate) {
+                                     bool apply_delegate,
+                                     bool allocate_and_delegate) {
   auto opcodes = builder_.CreateVector(opcodes_);
   auto operators = builder_.CreateVector(operators_);
   auto tensors = builder_.CreateVector(tensors_);
@@ -190,14 +202,9 @@
 
   interpreter_->SetAllowFp16PrecisionForFp32(allow_fp32_relax_to_fp16);
 
-  CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
-      << "Cannot allocate tensors";
-  interpreter_->ResetVariableTensors();
-
-  // In some rare cases a test may need to postpone modifying the graph with
-  // a delegate, e.g. if tensors are not fully specified. In such cases the
-  // test has to explicitly call ApplyDelegate() when necessary.
-  if (apply_delegate) ApplyDelegate();
+  if (allocate_and_delegate) {
+    AllocateAndDelegate(apply_delegate);
+  }
 }
 
 TfLiteStatus SingleOpModel::ApplyDelegate() {
@@ -229,7 +236,7 @@
     std::vector<std::vector<int>> input_shapes) {
   BuildInterpreter(input_shapes, /*num_threads=*/-1,
                    /*allow_fp32_relax_to_fp16=*/false,
-                   /*apply_delegate=*/true);
+                   /*apply_delegate=*/true, /*allocate_and_delegate=*/true);
 }
 
 // static
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 3e13335..7785b6c 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -392,11 +392,15 @@
                    const std::vector<uint8_t>& custom_option,
                    const std::function<TfLiteRegistration*()>& registration);
 
+  // Allocate tensors and apply delegate.
+  // Note that this is called by default in BuildInterpreter().
+  void AllocateAndDelegate(bool apply_delegate);
+
   // Build the interpreter for this model. Also, resize and allocate all
   // tensors given the shapes of the inputs.
   void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
                         int num_threads, bool allow_fp32_relax_to_fp16,
-                        bool apply_delegate);
+                        bool apply_delegate, bool allocate_and_delegate = true);
 
   void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
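A hypothetical test pattern the new flag enables (everything except the two
methods shown is a placeholder):

  model.BuildInterpreter(input_shapes, /*num_threads=*/-1,
                         /*allow_fp32_relax_to_fp16=*/false,
                         /*apply_delegate=*/true,
                         /*allocate_and_delegate=*/false);
  // ... fully specify any remaining tensor shapes here ...
  model.AllocateAndDelegate(/*apply_delegate=*/true);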
 
diff --git a/tensorflow/lite/micro/CONTRIBUTING.md b/tensorflow/lite/micro/CONTRIBUTING.md
new file mode 100644
index 0000000..f5a974b
--- /dev/null
+++ b/tensorflow/lite/micro/CONTRIBUTING.md
@@ -0,0 +1,168 @@
+# Resources
+
+A
+[TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
+should be the primary method of getting in touch with the TensorFlow Lite Micro
+(TFLM) team.
+
+The following resources may also be useful:
+
+1.  SIG Micro [email group](https://groups.google.com/a/tensorflow.org/g/micro)
+    and
+    [monthly meetings](http://doc/1YHq9rmhrOUdcZnrEnVCWvd87s2wQbq4z17HbeRl-DBc).
+
+1.  SIG Micro [gitter chat room](https://gitter.im/tensorflow/sig-micro).
+
+# Contributing Guidelines
+
+We look forward to your contributions to the TensorFlow Lite Micro codebase.
+These guidelines aim to enable community contributions while maintaining code
+health, maintainability, and consistency in style.
+
+Please note that while these guidelines may seem onerous to some developers,
+they are derived from Google's software engineering best practices.
+
+Before we describe project-specific guidelines, we recommend that external
+contributors read these tips from the Google Testing Blog:
+
+*   [Code Health: Providing Context with Commit Messages and Bug Reports](https://testing.googleblog.com/2017/09/code-health-providing-context-with.html)
+*   [Code Health: Understanding Code In Review](https://testing.googleblog.com/2018/05/code-health-understanding-code-in-review.html)
+*   [Code Health: Too Many Comments on Your Code Reviews?](https://testing.googleblog.com/2017/06/code-health-too-many-comments-on-your.html)
+*   [Code Health: To Comment or Not to Comment?](https://testing.googleblog.com/2017/07/code-health-to-comment-or-not-to-comment.html)
+
+We also recommend that contributors take a look at the
+[Tensorflow Contributing Guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md).
+
+## General Pull Request Guidelines
+
+We strongly recommend that contributors:
+
+1.  Initiate a conversation with the TFLM team via a
+    [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
+    as early as possible.
+
+    *   This enables us to give guidance on how to proceed, prevent duplicated
+        effort and also point to alternatives as well as context if we are not
+        able to accept a particular contribution at a given time.
+
+    *   Ideally, you should make an issue ***before*** starting to work on a
+        pull request and provide context on both what you want to contribute and
+        why.
+
+1.  Once step 1. is complete and it is determined that a PR from an external
+    contributor is the way to go, please follow these guidelines from
+    [Google's Engineering Practices documentation](https://google.github.io/eng-practices/):
+
+    *   [Send Small Pull Requests](https://google.github.io/eng-practices/review/developer/small-cls.html)
+
+        *   If a pull request is doing more than one thing, the reviewer will
+            request that it be broken up into two or more PRs.
+
+    *   [Write Good Pull Request Descriptions](https://google.github.io/eng-practices/review/developer/cl-descriptions.html)
+
+        *   We require that all PR descriptions link to the github issue created
+            in step 1.
+
+        *   While github offers flexibility in linking
+            [commits and issues](https://github.blog/2011-04-09-issues-2-0-the-next-generation/#commits-issues),
+            we require that the PR description have a separate line with either
+            `Fixes #nn` (if the PR fixes the issue) or `Issue #nn` if the PR
+            addresses some aspect of an issue without fixing it.
+
+        *   We will be adding internal checks that automate this requirement by
+            matching the PR description to the regexp: `(Fixes|Issue) #`
+
+1.  Unit tests are critical to a healthy codebase. PRs without tests should be
+    the exception rather than the norm. And contributions to improve, simplify,
+    or make the unit tests more exhaustive are welcome! Please refer to
+    [this guideline](https://google.github.io/eng-practices/review/developer/small-cls.html#test_code)
+    on how test code and writing small PRs should be reconciled.
+
+## Guidelines for Specific Contribution Categories
+
+We provide some additional guidelines for different categories of contributions.
+
+### Bug Fixes
+
+Pull requests that fix bugs are always welcome and often uncontroversial,
+unless there is a conflict between requirements across platforms or the fix
+requires a larger architectural change.
+
+1.  Create a
+    [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
+    to determine the scope of the bug fix.
+1.  Send a PR (if that is determined to be the best path forward).
+1.  Bugfix PRs should be accompanied by a test case that fails prior to the
+    fix and passes with the fix. This validates that the fix works as expected
+    and helps prevent future regressions. A minimal sketch of such a test
+    follows this list.
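+
+As a sketch only (the include path and test body are assumptions based on the
+TFLM test conventions, not a required template), a regression test might look
+like:
+
+```
+#include "tensorflow/lite/micro/testing/micro_test.h"
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+// Hypothetical regression test: exercise the exact inputs that triggered the
+// bug, so that this test fails before the fix and passes after it.
+TF_LITE_MICRO_TEST(BugRegressionExample) {
+  const int kExpected = 4;
+  const int kActual = 2 + 2;  // Stand-in for a call into the fixed code path.
+  TF_LITE_MICRO_EXPECT_EQ(kExpected, kActual);
+}
+
+TF_LITE_MICRO_TESTS_END
+```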
+
+### Reference Kernel Implementations
+
+Pull requests that port reference kernels from TF Lite Mobile to TF Lite Micro
+are welcome once we have enough context from the contributor on why the
+additional kernel is needed.
+
+1.  Please create a
+    [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
+    before starting on any such PRs with as much context as possible, such as:
+
+    *   What is the model architecture?
+    *   What is the application that you are targeting?
+    *   What embedded target(s) are you planning to run on?
+    *   Motivate your use case and the need to add support for this
+        additional op.
+
+1.  In the interest of having
+    [small pull requests](https://google.github.io/eng-practices/review/developer/small-cls.html),
+    limit each pull request to porting a single kernel (and the corresponding
+    test).
+
+1.  TODO(b/165627437): Create and link to a guide to porting reference ops.
+
+### Optimized Kernel Implementations
+
+To make the TFLM codebase a central repository of optimized kernel
+implementations, we would like to improve the current infrastructure so that
+optimized kernel implementations can be added and maintained in a scalable
+way.
+
+Until that work is complete, we are requesting a ***pause*** on contributions that
+add new optimized kernel implementations. We plan to make these improvements by
+October 2020 and will provide additional guidelines at that time.
+
+*   If you would like an exception to this pause, with the understanding that
+    your optimized kernels will break as we improve the underlying framework,
+    please send an email to the
+    [SIG Micro email group](https://groups.google.com/a/tensorflow.org/g/micro)
+    to figure out a middle ground.
+
+*   Every optimized kernel directory must have a README.md with the GitHub IDs
+    of the maintainers and any other relevant documentation; a hypothetical
+    example follows this list. PRs that add maintainers to the existing
+    optimized kernels are always welcome.
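+
+As a hypothetical example (the directory and handle are invented), such a
+README.md could be as small as:
+
+```
+# Optimized kernels for <target>
+
+## Maintainers
+
+*   [your-github-id](https://github.com/your-github-id)
+```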
+
+### New Target / Platform / IDE / Examples
+
+As discussed in the
+[SIG-micro Aug 12, 2020 meeting](http://doc/1YHq9rmhrOUdcZnrEnVCWvd87s2wQbq4z17HbeRl-DBc),
+we are currently ***pausing*** acceptance of pull requests that add new
+targets, platforms, IDE integrations, or examples while we revisit some of
+the infrastructure to make this process easier and more scalable.
+
+In the meantime, snapshotting and/or forking the TensorFlow repo could be a
+viable way to prototype platform support.
+
+Having said that, we still invite
+[TF Lite Micro Github issues](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
+on this topic as we would like to enable such integration in the future.
+
+### New Features
+
+As discussed in the
+[SIG-micro Aug 12, 2020 meeting](http://doc/1YHq9rmhrOUdcZnrEnVCWvd87s2wQbq4z17HbeRl-DBc),
+we are currently ***pausing*** acceptance of pull requests that add new
+features while we revisit some of the infrastructure to make this process
+easier and more scalable.
+
+Having said that, we still invite feature requests via
+[TF Lite Micro Github issues](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
+to determine if the requested feature aligns with the TFLM roadmap.
diff --git a/tensorflow/lite/micro/benchmarks/Makefile.inc b/tensorflow/lite/micro/benchmarks/Makefile.inc
index 4a57ef3..d9dfba2 100644
--- a/tensorflow/lite/micro/benchmarks/Makefile.inc
+++ b/tensorflow/lite/micro/benchmarks/Makefile.inc
@@ -1,7 +1,3 @@
-$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,))
-$(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,))
-
-
 KEYWORD_BENCHMARK_SRCS := \
 tensorflow/lite/micro/benchmarks/keyword_benchmark.cc \
 tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.cc
diff --git a/tensorflow/lite/micro/examples/hello_world/sparkfun_edge/output_handler.cc b/tensorflow/lite/micro/examples/hello_world/sparkfun_edge/output_handler.cc
index 2e72709..87f2cdf 100644
--- a/tensorflow/lite/micro/examples/hello_world/sparkfun_edge/output_handler.cc
+++ b/tensorflow/lite/micro/examples/hello_world/sparkfun_edge/output_handler.cc
@@ -55,7 +55,7 @@
     // The blue LED is lit for all negative values
     am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_BLUE);
     // The red LED is lit in only some cases
-    if (y_value <= -0.75) {
+    if (y_value <= -0.75f) {
       am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_RED);
     } else {
       am_devices_led_off(am_bsp_psLEDs, AM_BSP_LED_RED);
@@ -68,13 +68,14 @@
     // The green LED is lit for all positive values
     am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_GREEN);
     // The yellow LED is lit in only some cases
-    if (y_value >= 0.75) {
+    if (y_value >= 0.75f) {
       am_devices_led_on(am_bsp_psLEDs, AM_BSP_LED_YELLOW);
     } else {
       am_devices_led_off(am_bsp_psLEDs, AM_BSP_LED_YELLOW);
     }
   }
   // Log the current X and Y values
-  TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n", x_value,
-                       y_value);
+  TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n",
+                       static_cast<double>(x_value),
+                       static_cast<double>(y_value));
 }
diff --git a/tensorflow/lite/micro/examples/person_detection/Makefile.inc b/tensorflow/lite/micro/examples/person_detection/Makefile.inc
index a295bb8..304dd95 100644
--- a/tensorflow/lite/micro/examples/person_detection/Makefile.inc
+++ b/tensorflow/lite/micro/examples/person_detection/Makefile.inc
@@ -1,6 +1,3 @@
-$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,))
-$(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,))
-
 person_detection_MODEL_SRCS := \
 tensorflow/lite/micro/examples/person_detection/model_settings.cc \
 $(MAKEFILE_DIR)/downloads/person_model_grayscale/person_detect_model_data.cc
diff --git a/tensorflow/lite/micro/kernels/add.cc b/tensorflow/lite/micro/kernels/add.cc
index 79a0487..7c63eea 100644
--- a/tensorflow/lite/micro/kernels/add.cc
+++ b/tensorflow/lite/micro/kernels/add.cc
@@ -110,19 +110,22 @@
   tflite::ArithmeticParams op_params;
   SetActivationParams(data->output_activation_min_f32,
                       data->output_activation_max_f32, &op_params);
-#define TF_LITE_ADD(opname)                                               \
-  reference_ops::opname(op_params, tflite::micro::GetTensorShape(input1), \
-                        tflite::micro::GetTensorData<float>(input1),      \
-                        tflite::micro::GetTensorShape(input2),            \
-                        tflite::micro::GetTensorData<float>(input2),      \
-                        tflite::micro::GetTensorShape(output),            \
-                        tflite::micro::GetTensorData<float>(output))
   if (data->requires_broadcast) {
-    TF_LITE_ADD(BroadcastAdd4DSlow);
+    reference_ops::BroadcastAdd4DSlow(
+        op_params, tflite::micro::GetTensorShape(input1),
+        tflite::micro::GetTensorData<float>(input1),
+        tflite::micro::GetTensorShape(input2),
+        tflite::micro::GetTensorData<float>(input2),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<float>(output));
   } else {
-    TF_LITE_ADD(Add);
+    reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1),
+                       tflite::micro::GetTensorData<float>(input1),
+                       tflite::micro::GetTensorShape(input2),
+                       tflite::micro::GetTensorData<float>(input2),
+                       tflite::micro::GetTensorShape(output),
+                       tflite::micro::GetTensorData<float>(output));
   }
-#undef TF_LITE_ADD
 }
 
 TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
@@ -147,27 +150,42 @@
     bool need_broadcast = reference_ops::ProcessBroadcastShapes(
         tflite::micro::GetTensorShape(input1),
         tflite::micro::GetTensorShape(input2), &op_params);
-#define TF_LITE_ADD(type, opname, dtype)                         \
-  type::opname(op_params, tflite::micro::GetTensorShape(input1), \
-               tflite::micro::GetTensorData<dtype>(input1),      \
-               tflite::micro::GetTensorShape(input2),            \
-               tflite::micro::GetTensorData<dtype>(input2),      \
-               tflite::micro::GetTensorShape(output),            \
-               tflite::micro::GetTensorData<dtype>(output));
     if (output->type == kTfLiteInt8) {
       if (need_broadcast) {
-        TF_LITE_ADD(reference_integer_ops, BroadcastAdd4DSlow, int8_t);
+        reference_integer_ops::BroadcastAdd4DSlow(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<int8_t>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<int8_t>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int8_t>(output));
       } else {
-        TF_LITE_ADD(reference_integer_ops, Add, int8_t);
+        reference_integer_ops::Add(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<int8_t>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<int8_t>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int8_t>(output));
       }
     } else {
       if (need_broadcast) {
-        TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, uint8_t);
+        reference_ops::BroadcastAdd4DSlow(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<uint8_t>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<uint8_t>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<uint8_t>(output));
       } else {
-        TF_LITE_ADD(reference_ops, Add, uint8_t);
+        reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1),
+                           tflite::micro::GetTensorData<uint8_t>(input1),
+                           tflite::micro::GetTensorShape(input2),
+                           tflite::micro::GetTensorData<uint8_t>(input2),
+                           tflite::micro::GetTensorShape(output),
+                           tflite::micro::GetTensorData<uint8_t>(output));
       }
     }
-#undef TF_LITE_ADD
   }
 
   return kTfLiteOk;
diff --git a/tensorflow/lite/micro/kernels/circular_buffer.cc b/tensorflow/lite/micro/kernels/circular_buffer.cc
index b5a8ae1..7f5aeba 100644
--- a/tensorflow/lite/micro/kernels/circular_buffer.cc
+++ b/tensorflow/lite/micro/kernels/circular_buffer.cc
@@ -17,8 +17,6 @@
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
-#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
-#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
diff --git a/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc b/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc
index ddb1444..afdc564 100644
--- a/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc
+++ b/tensorflow/lite/micro/kernels/ethos-u/ethosu.cc
@@ -16,6 +16,7 @@
 #include <ethosu_driver.h>
 
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/tools/make/downloads/flatbuffers/include/flatbuffers/flexbuffers.h"
 
 namespace tflite {
@@ -26,30 +27,51 @@
 
 constexpr uint8_t CO_TYPE_ETHOSU = 1;
 
+struct OpData {
+  int cms_data_size;
+  int buffer_idx;
+};
+
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  return nullptr;
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
 }
 
 void Free(TfLiteContext* context, void* buffer) {}
 
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(context != nullptr);
   TF_LITE_ENSURE(context, node->inputs->size > 0);
-  TF_LITE_ENSURE(context, context->tensors);
+  TFLITE_DCHECK(node->user_data != nullptr);
   TF_LITE_ENSURE(context, node->custom_initial_data_size > 0);
+
+  OpData* data = static_cast<OpData*>(node->user_data);
+  int num_base_addr = node->inputs->size + node->outputs->size;
+  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
+      context, num_base_addr * sizeof(uint64_t), &data->buffer_idx));
+
+  // Get command stream data size
+  TfLiteTensor* tensor = context->GetTensor(context, node->inputs->data[0]);
+  data->cms_data_size = tensor->bytes;
+
   return kTfLiteOk;
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(context != nullptr);
+  TFLITE_DCHECK(context->GetScratchBuffer != nullptr);
+
   // Get base addresses
-  TfLiteTensor* tensor;
-  int num_base_addr = node->inputs->size + node->outputs->size;
+  TfLiteEvalTensor* tensor;
   int i = 0;
   int num_tensors = 0;
-  uint64_t base_addrs[num_base_addr];
   void* cms_data;
-  int cms_data_size;
   uint8_t co_type;
   int result;
+  const OpData* data = static_cast<const OpData*>(node->user_data);
+  uint64_t* base_addrs = static_cast<uint64_t*>(
+      context->GetScratchBuffer(context, data->buffer_idx));
 
   const uint8_t* custom_data =
       static_cast<uint8_t const*>(node->custom_initial_data);
@@ -60,26 +82,30 @@
     return kTfLiteError;
   }
 
-  // Get command stream data address and size
-  tensor = &(context->tensors[node->inputs->data[0]]);
+  // Get command stream data address
+  tensor = context->GetEvalTensor(context, node->inputs->data[0]);
   cms_data = reinterpret_cast<void*>(tensor->data.uint8);
-  cms_data_size = tensor->bytes;
 
   // Get adresses to weights/scratch/input data
   for (i = 1; i < node->inputs->size; ++i) {
-    tensor = &(context->tensors[node->inputs->data[i]]);
+    tensor = context->GetEvalTensor(context, node->inputs->data[i]);
     base_addrs[num_tensors] = reinterpret_cast<uint64_t>(tensor->data.uint8);
     num_tensors++;
   }
 
   // Get adresses to output data
   for (i = 0; i < node->outputs->size; ++i) {
-    tensor = &(context->tensors[node->outputs->data[i]]);
+    tensor = context->GetEvalTensor(context, node->outputs->data[i]);
     base_addrs[num_tensors] = reinterpret_cast<uint64_t>(tensor->data.uint8);
     num_tensors++;
   }
 
-  result = ethosu_invoke(cms_data, cms_data_size, base_addrs, num_tensors);
+  // Ethos-U guarantees that the tensors that require a base pointer are among
+  // the first 8 tensors.
+  num_tensors = std::min(num_tensors, 8);
+
+  result =
+      ethosu_invoke(cms_data, data->cms_data_size, base_addrs, num_tensors);
   if (-1 == result) {
     return kTfLiteError;
   } else {
@@ -89,8 +115,16 @@
 
 }  // namespace ethosu
 
-TfLiteRegistration Register_ETHOSU() {
-  return {ethosu::Init, ethosu::Free, ethosu::Prepare, ethosu::Eval};
+TfLiteRegistration* Register_ETHOSU() {
+  static TfLiteRegistration r = {ethosu::Init,
+                                 ethosu::Free,
+                                 ethosu::Prepare,
+                                 ethosu::Eval,
+                                 /*profiling_string=*/nullptr,
+                                 /*builtin_code=*/0,
+                                 /*custom_name=*/nullptr,
+                                 /*version=*/0};
+  return &r;
 }
 
 const char* GetString_ETHOSU() { return "ethos-u"; }
diff --git a/tensorflow/lite/micro/kernels/reduce.cc b/tensorflow/lite/micro/kernels/reduce.cc
index 5cae782..2f9e639 100644
--- a/tensorflow/lite/micro/kernels/reduce.cc
+++ b/tensorflow/lite/micro/kernels/reduce.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -32,6 +33,20 @@
 constexpr int kMaxNumberOfAxis = 4;
 constexpr int kMaxNumberOfReducedAxis = 2;
 
+struct OpData {
+  int32_t multiplier;
+  int shift;
+  int temp_buffer_idx;
+  int input_zp;
+  float input_scale;
+  int output_zp;
+  float output_scale;
+};
+
+void* InitMean(TfLiteContext* context, const char* buffer, size_t length) {
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
 TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
   // Inputs Tensor (dtype depends on quantization):
   // [0] = Input
@@ -51,6 +66,25 @@
 }
 
 TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+  const TfLiteTensor* output = GetOutput(context, node, 0);
+  if (input->type == kTfLiteInt8) {
+    const double real_multiplier = static_cast<double>(input->params.scale) /
+                                   static_cast<double>(output->params.scale);
+    QuantizeMultiplier(real_multiplier, &op_data->multiplier, &op_data->shift);
+  }
+
+  int output_size = NumElements(output);
+  if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
+    context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
+                                         &op_data->temp_buffer_idx);
+    op_data->input_zp = input->params.zero_point;
+    op_data->input_scale = input->params.scale;
+    op_data->output_zp = output->params.zero_point;
+    op_data->output_scale = output->params.scale;
+  }
+
   TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
   // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
   return kTfLiteOk;
@@ -74,26 +108,25 @@
   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
   TfLiteReducerParams* params =
       reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
 
   int num_axis = static_cast<int>(ElementCount(*axis->dims));
   int temp_index[kMaxNumberOfAxis];
   int resolved_axis[kMaxNumberOfReducedAxis];
 
+  tflite::MeanParams op_params;
+  ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);
+  // TODO(b/146571391): Support only 4D Input and 2D Axis for Mean until
+  // scratch tensor allocation has been implemented in (b/132070898)
+  bool is_valid_inputs = (input->dims->size == 4 && op_params.axis_count == 2 &&
+                          ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+                           (op_params.axis[0] == 2 && op_params.axis[1] == 1)));
+  TF_LITE_ENSURE_MSG(
+      context, is_valid_inputs == true,
+      "Number of Input "
+      "dimensions != 4 OR the Axis is not either [1, 2] or [2, 1]");
   switch (input->type) {
     case kTfLiteFloat32: {
-      tflite::MeanParams op_params;
-      ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis,
-                  &op_params);
-      // TODO(b/146571391): Support only 4D Input and 2D Axis for Mean until
-      // scratch tensor allocation has been implemented in (b/132070898)
-      bool is_valid_inputs =
-          (input->dims->size == 4 && op_params.axis_count == 2 &&
-           ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
-            (op_params.axis[0] == 2 && op_params.axis[1] == 1)));
-      TF_LITE_ENSURE_MSG(
-          context, is_valid_inputs == true,
-          "Number of Input "
-          "dimensions != 4 OR the Axis is not either [1, 2] or [2, 1]");
       // TODO(b/139102329): Handle the below special case in the combined
       // reference method.
       // Defer to specialized implementation for 4D Mean across axes 1 & 2.
@@ -114,10 +147,81 @@
                 tflite::micro::GetTensorData<float>(output)));
       }
     } break;
+    case kTfLiteInt8: {
+      if (params->keep_dims) {
+        reference_integer_ops::Mean(
+            op_params, op_data->multiplier, op_data->shift,
+            tflite::micro::GetTensorShape(input),
+            tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
+      } else if (op_data->input_zp == op_data->output_zp &&
+                 op_data->input_scale == op_data->output_scale) {
+        int32_t* temp_buffer = static_cast<int32_t*>(
+            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
+        TF_LITE_ENSURE(
+            context,
+            reference_ops::Mean(
+                tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
+                input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
+                output->dims->data, output->dims->size,
+                tflite::micro::GetTensorData<int>(axis), num_axis,
+                params->keep_dims, temp_index, resolved_axis, temp_buffer));
+      } else {
+        int32_t* temp_buffer = static_cast<int32_t*>(
+            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
+        TF_LITE_ENSURE(
+            context,
+            reference_ops::QuantizedMeanOrSum(
+                tflite::micro::GetTensorData<int8_t>(input), op_data->input_zp,
+                op_data->input_scale, input->dims->data, input->dims->size,
+                tflite::micro::GetTensorData<int8_t>(output),
+                op_data->output_zp, op_data->output_scale, output->dims->data,
+                output->dims->size, tflite::micro::GetTensorData<int>(axis),
+                num_axis, params->keep_dims, temp_index, resolved_axis,
+                temp_buffer, false));
+      }
+    } break;
+    case kTfLiteUInt8: {
+      if (params->keep_dims) {
+        reference_ops::Mean(op_params, tflite::micro::GetTensorShape(input),
+                            tflite::micro::GetTensorData<uint8_t>(input),
+                            op_data->input_zp, op_data->input_scale,
+                            tflite::micro::GetTensorShape(output),
+                            tflite::micro::GetTensorData<uint8_t>(output),
+                            op_data->output_zp, op_data->output_scale);
+      } else if (op_data->input_zp == op_data->output_zp &&
+                 op_data->input_scale == op_data->output_scale) {
+        uint32_t* temp_buffer = static_cast<uint32_t*>(
+            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
+        TF_LITE_ENSURE(
+            context,
+            reference_ops::Mean(tflite::micro::GetTensorData<uint8_t>(input),
+                                input->dims->data, input->dims->size,
+                                tflite::micro::GetTensorData<uint8_t>(output),
+                                output->dims->data, output->dims->size,
+                                tflite::micro::GetTensorData<int>(axis),
+                                num_axis, params->keep_dims, temp_index,
+                                resolved_axis, temp_buffer));
+      } else {
+        uint32_t* temp_buffer = static_cast<uint32_t*>(
+            context->GetScratchBuffer(context, op_data->temp_buffer_idx));
+        TF_LITE_ENSURE(
+            context,
+            reference_ops::QuantizedMeanOrSum(
+                tflite::micro::GetTensorData<uint8_t>(input), op_data->input_zp,
+                op_data->input_scale, input->dims->data, input->dims->size,
+                tflite::micro::GetTensorData<uint8_t>(output),
+                op_data->output_zp, op_data->output_scale, output->dims->data,
+                output->dims->size, tflite::micro::GetTensorData<int>(axis),
+                num_axis, params->keep_dims, temp_index, resolved_axis,
+                temp_buffer, false));
+      }
+    } break;
     default:
       // TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
       TF_LITE_ENSURE_MSG(context, false,
-                         "Currently, only float32 input type "
+                         "Currently, only float32, int8 or uint8 input type "
                          "is supported.");
   }
   return kTfLiteOk;
@@ -125,7 +229,7 @@
 }  // namespace reduce
 
 TfLiteRegistration Register_MEAN() {
-  return {/*init=*/nullptr,
+  return {/*init=*/reduce::InitMean,
           /*free=*/nullptr,
           /*prepare=*/reduce::PrepareMeanOrSum,
           /*invoke=*/reduce::EvalMean,
diff --git a/tensorflow/lite/micro/kernels/reduce_test.cc b/tensorflow/lite/micro/kernels/reduce_test.cc
index 1e3ded2..3207063 100644
--- a/tensorflow/lite/micro/kernels/reduce_test.cc
+++ b/tensorflow/lite/micro/kernels/reduce_test.cc
@@ -25,7 +25,7 @@
 namespace {
 
 // Common inputs and outputs.
-// static const int kInputElements4D = 24;
+static const int kInputElements4D = 24;
 static const int kInputShape4D[] = {4, 2, 2, 3, 2};
 static const float kInputData4D[] = {
     1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
@@ -90,6 +90,44 @@
                             output_data, output_dims_count, params, tolerance));
 }
 
+template <typename T>
+void TestMeanOpQuantized(const int* input_dims_data, const float* input_data,
+                         T* input_data_quant, float input_scale,
+                         int input_zero_point, const int* axis_dims_data,
+                         const int32_t* axis_data, const int* output_dims_data,
+                         const float* expected_output_data,
+                         T* output_data_quant, T* expected_output_data_quant,
+                         float output_scale, int output_zero_point,
+                         TfLiteReducerParams* params) {
+  // Convert dimension arguments to TfLiteArrays
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+
+  // Get number of elements in the output tensor
+  const int output_dims_count = ElementCount(*output_dims);
+
+  // Initialize tensors
+  constexpr int tensors_size = 3;
+  TfLiteTensor tensors[] = {
+      CreateQuantizedTensor(input_data, input_data_quant, input_dims,
+                            input_scale, input_zero_point),
+      CreateInt32Tensor(axis_data, axis_dims),
+      CreateQuantizedTensor(output_data_quant, output_dims, output_scale,
+                            output_zero_point),
+  };
+
+  // Quantize expected output
+  tflite::AsymmetricQuantize(expected_output_data, expected_output_data_quant,
+                             output_dims_count, output_scale,
+                             output_zero_point);
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk,
+      ValidateReduceGoldens(tensors, tensors_size, expected_output_data_quant,
+                            output_data_quant, output_dims_count, params, 1.0));
+}
+
 }  // namespace
 }  // namespace testing
 }  // namespace tflite
@@ -110,10 +148,55 @@
       &params);
 }
 
+TF_LITE_MICRO_TEST(MeanInt84DKeepDims) {
+  int8_t expected_output_data_quant[tflite::testing::kOutputElements];
+  int8_t output_data_quant[tflite::testing::kOutputElements];
+  int8_t input_data_quant[tflite::testing::kInputElements4D];
+
+  float input_scale = 0.5f;
+  int input_zero_point = 0;
+  float output_scale = 0.5f;
+  int output_zero_point = 0;
+
+  TfLiteReducerParams params = {
+      true  // keep_dims
+  };
+
+  tflite::testing::TestMeanOpQuantized<int8_t>(
+      tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
+      input_data_quant, input_scale, input_zero_point,
+      tflite::testing::kAxisShape, tflite::testing::kAxisData,
+      tflite::testing::kOutputShape, tflite::testing::kGoldenData,
+      output_data_quant, expected_output_data_quant, output_scale,
+      output_zero_point, &params);
+}
+
+TF_LITE_MICRO_TEST(MeanUInt84DKeepDims) {
+  uint8_t expected_output_data_quant[tflite::testing::kOutputElements];
+  uint8_t output_data_quant[tflite::testing::kOutputElements];
+  uint8_t input_data_quant[tflite::testing::kInputElements4D];
+
+  float input_scale = 0.5f;
+  int input_zero_point = 128;
+  float output_scale = 0.5f;
+  int output_zero_point = 128;
+
+  TfLiteReducerParams params = {
+      true  // keep_dims
+  };
+
+  tflite::testing::TestMeanOpQuantized<uint8_t>(
+      tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
+      input_data_quant, input_scale, input_zero_point,
+      tflite::testing::kAxisShape, tflite::testing::kAxisData,
+      tflite::testing::kOutputShape, tflite::testing::kGoldenData,
+      output_data_quant, expected_output_data_quant, output_scale,
+      output_zero_point, &params);
+}
+
 TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDims) {
   const int kOutputShape[] = {2, 2, 2};
   float output_data[tflite::testing::kOutputElements];
-
   TfLiteReducerParams params = {
       false  // keep_dims
   };
@@ -124,6 +207,50 @@
       tflite::testing::kGoldenData, output_data, &params);
 }
 
+TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDims) {
+  int8_t expected_output_data_quant[tflite::testing::kOutputElements];
+  int8_t output_data_quant[tflite::testing::kOutputElements];
+  int8_t input_data_quant[tflite::testing::kInputElements4D];
+
+  const int kOutputShape[] = {2, 2, 2};
+  TfLiteReducerParams params = {
+      false  // keep_dims
+  };
+  float input_scale = 0.5f;
+  int input_zero_point = 0;
+  float output_scale = 0.5f;
+  int output_zero_point = 0;
+
+  tflite::testing::TestMeanOpQuantized<int8_t>(
+      tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
+      input_data_quant, input_scale, input_zero_point,
+      tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape,
+      tflite::testing::kGoldenData, output_data_quant,
+      expected_output_data_quant, output_scale, output_zero_point, &params);
+}
+
+TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDims) {
+  uint8_t expected_output_data_quant[tflite::testing::kOutputElements];
+  uint8_t output_data_quant[tflite::testing::kOutputElements];
+  uint8_t input_data_quant[tflite::testing::kInputElements4D];
+
+  const int kOutputShape[] = {2, 2, 2};
+  TfLiteReducerParams params = {
+      false  // keep_dims
+  };
+  float input_scale = 0.5f;
+  int input_zero_point = 128;
+  float output_scale = 0.5f;
+  int output_zero_point = 128;
+
+  tflite::testing::TestMeanOpQuantized<uint8_t>(
+      tflite::testing::kInputShape4D, tflite::testing::kInputData4D,
+      input_data_quant, input_scale, input_zero_point,
+      tflite::testing::kAxisShape, tflite::testing::kAxisData, kOutputShape,
+      tflite::testing::kGoldenData, output_data_quant,
+      expected_output_data_quant, output_scale, output_zero_point, &params);
+}
+
 TF_LITE_MICRO_TEST(MeanFloat4DWithoutKeepDimsWithPrecision) {
   const int kInputShape4D[] = {4, 2, 2, 3, 1};
   const float kInputData4D[] = {1.0,  24.0, 13.0, 3.0,  9.0,  17.0,
@@ -132,7 +259,6 @@
   const int kOutputShape[] = {2, 2, 1};
   const float kGoldenData[] = {11.166667, 19.833334};
   float output_data[kOutputElements];
-
   TfLiteReducerParams params = {
       false  // keep_dims
   };
@@ -143,4 +269,54 @@
       &params);
 }
 
+TF_LITE_MICRO_TEST(MeanInt84DWithoutKeepDimsWithPrecision) {
+  const int kInputShape4D[] = {4, 2, 2, 3, 1};
+  const float kInputData4D[] = {1.0,  24.0, 13.0, 3.0,  9.0,  17.0,
+                                11.0, 36.0, 14.0, 19.0, 17.0, 22.0};
+  const int kOutputShape[] = {2, 2, 1};
+  const float kGoldenData[] = {11.166667, 19.833334};
+  TfLiteReducerParams params = {
+      false  // keep_dims
+  };
+  float input_scale = 0.5f;
+  int input_zero_point = 0;
+  float output_scale = 0.5f;
+  int output_zero_point = 0;
+
+  int8_t output_data_quant[2];
+  int8_t expected_output_data_quant[2];
+  int8_t input_data_quant[12];
+
+  tflite::testing::TestMeanOpQuantized<int8_t>(
+      kInputShape4D, kInputData4D, input_data_quant, input_scale,
+      input_zero_point, tflite::testing::kAxisShape, tflite::testing::kAxisData,
+      kOutputShape, kGoldenData, output_data_quant, expected_output_data_quant,
+      output_scale, output_zero_point, &params);
+}
+
+TF_LITE_MICRO_TEST(MeanUInt84DWithoutKeepDimsWithPrecision) {
+  const int kInputShape4D[] = {4, 2, 2, 3, 1};
+  const float kInputData4D[] = {1.0,  24.0, 13.0, 3.0,  9.0,  17.0,
+                                11.0, 36.0, 14.0, 19.0, 17.0, 22.0};
+  const int kOutputShape[] = {2, 2, 1};
+  const float kGoldenData[] = {11.166667, 19.833334};
+  TfLiteReducerParams params = {
+      false  // keep_dims
+  };
+
+  float input_scale = 0.5f;
+  int input_zero_point = 128;
+  float output_scale = 0.5f;
+  int output_zero_point = 128;
+
+  uint8_t output_data_quant[2];
+  uint8_t expected_output_data_quant[2];
+  uint8_t input_data_quant[12];
+
+  tflite::testing::TestMeanOpQuantized<uint8_t>(
+      kInputShape4D, kInputData4D, input_data_quant, input_scale,
+      input_zero_point, tflite::testing::kAxisShape, tflite::testing::kAxisData,
+      kOutputShape, kGoldenData, output_data_quant, expected_output_data_quant,
+      output_scale, output_zero_point, &params);
+}
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc
index e85c1a4..fa1b9ca 100644
--- a/tensorflow/lite/micro/kernels/softmax.cc
+++ b/tensorflow/lite/micro/kernels/softmax.cc
@@ -30,23 +30,30 @@
 namespace activations {
 namespace {
 
+// Number of entries in the LUTs (persisted in user_data) used by the int16
+// softmax implementation.
+static constexpr int kInt16LUTArraySize = 513;
+
 TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
                                     const TfLiteTensor* input,
                                     TfLiteTensor* output,
                                     const TfLiteSoftmaxParams* params,
                                     SoftmaxParams* op_data) {
-  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
+      input->type == kTfLiteInt16) {
     if (input->type == kTfLiteUInt8) {
       TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
-    } else {
+    } else if (input->type == kTfLiteInt16) {
+      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+      TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
+                          (0.001f * 1.f / 32768));
+    } else {  // input->type == kTfLiteInt8
       TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
       if (output->type == kTfLiteInt16) {
         TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
-        // NOTE: Current int16_t softmax output does not require symmetric
-        // scaling
-        // - so no need to verify scale here.
-      } else {
+        TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 65536,
+                            (0.001f * 1.f / 65536));
+      } else {  // output->type == kTfLiteInt8
         TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
         TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
         TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
@@ -55,15 +62,28 @@
 
     static const int kScaledDiffIntegerBits = 5;
 
-    int input_left_shift;
-    tflite::PreprocessSoftmaxScaling(
-        static_cast<double>(params->beta),
-        static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
-        &op_data->input_multiplier, &input_left_shift);
-    op_data->input_left_shift = input_left_shift;
-    op_data->diff_min =
-        -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
-                                            op_data->input_left_shift);
+    // Calculate input_multiplier and input_left_shift
+    if (input->type == kTfLiteInt16) {
+      int input_left_shift;
+      double input_scale_beta_rescale =
+          static_cast<double>(input->params.scale) *
+          static_cast<double>(params->beta) /
+          (10.0 / 65535.0);  // scale the input_diff such that [-65535, 0]
+                             // correspond to [-10.0, 0.0]
+      QuantizeMultiplier(input_scale_beta_rescale, &op_data->input_multiplier,
+                         &input_left_shift);
+      op_data->input_left_shift = input_left_shift;
+    } else {
+      int input_left_shift;
+      tflite::PreprocessSoftmaxScaling(
+          static_cast<double>(params->beta),
+          static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
+          &op_data->input_multiplier, &input_left_shift);
+      op_data->input_left_shift = input_left_shift;
+      op_data->diff_min =
+          -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+                                              op_data->input_left_shift);
+    }
   } else {
     TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
     TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
@@ -91,7 +111,7 @@
         tflite::micro::GetTensorData<uint8_t>(input),
         tflite::micro::GetTensorShape(output),
         tflite::micro::GetTensorData<uint8_t>(output));
-  } else {
+  } else if (input->type == kTfLiteInt8) {
     if (output->type == kTfLiteInt16) {
       tflite::reference_ops::Softmax(
           op_data, tflite::micro::GetTensorShape(input),
@@ -105,6 +125,12 @@
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
     }
+  } else {
+    tflite::reference_ops::SoftmaxInt16(
+        op_data, tflite::micro::GetTensorShape(input),
+        tflite::micro::GetTensorData<int16_t>(input),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int16_t>(output));
   }
 }
 
@@ -114,18 +140,50 @@
 }
 
 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
-
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   const TfLiteTensor* input = GetInput(context, node, 0);
   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
-
   TfLiteTensor* output = GetOutput(context, node, 0);
 
-  TFLITE_DCHECK(node->user_data != nullptr);
-  SoftmaxParams* data = static_cast<SoftmaxParams*>(node->user_data);
-  return CalculateSoftmaxParams(context, input, output, params, data);
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
+  SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
+  // Only allocate LUTs for kTfLiteInt16 data type
+  if (input->type == kTfLiteInt16) {
+    void* raw_exp_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
+    op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
+    void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
+        context, sizeof(int16_t) * kInt16LUTArraySize);
+    TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
+    op_data->one_over_one_plus_x_lut =
+        reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
+  }
+
+  if (output->type == kTfLiteInt16) {
+    TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
+                                input->type == kTfLiteUInt8 ||
+                                input->type == kTfLiteInt16);
+  } else {
+    TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  }
+
+  // Populate LUT if required
+  if (input->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+    // The exp LUT is only used on negative values; we consider exp(-10.0)
+    // insignificant to the accumulation.
+    gen_lut([](float value) { return std::exp(value); }, -10.0f, 0.0f,
+            op_data->exp_lut, kInt16LUTArraySize);
+    gen_lut([](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f,
+            op_data->one_over_one_plus_x_lut, kInt16LUTArraySize);
+    op_data->zero_point = output->params.zero_point;
+    op_data->scale = output->params.scale;
+  }
+
+  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+  return CalculateSoftmaxParams(context, input, output, params, op_data);
 }
 
 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
@@ -133,16 +191,17 @@
   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
 
   TFLITE_DCHECK(node->user_data != nullptr);
-  SoftmaxParams* data = static_cast<SoftmaxParams*>(node->user_data);
+  SoftmaxParams op_data = *static_cast<SoftmaxParams*>(node->user_data);
 
   switch (input->type) {
     case kTfLiteFloat32: {
-      SoftmaxFloat(input, output, *data);
+      SoftmaxFloat(input, output, op_data);
       return kTfLiteOk;
     }
     case kTfLiteInt8:
-    case kTfLiteUInt8: {
-      SoftmaxQuantized(input, output, *data);
+    case kTfLiteUInt8:
+    case kTfLiteInt16: {
+      SoftmaxQuantized(input, output, op_data);
       return kTfLiteOk;
     }
     default:
diff --git a/tensorflow/lite/micro/kernels/softmax_test.cc b/tensorflow/lite/micro/kernels/softmax_test.cc
index 27828d2..808ea93 100644
--- a/tensorflow/lite/micro/kernels/softmax_test.cc
+++ b/tensorflow/lite/micro/kernels/softmax_test.cc
@@ -28,8 +28,13 @@
 // quantization parameters.
 const float output_scale_int8 = 1.0f / 256.0f;
 const float output_scale_uint8 = 1.0f / 256.0f;
+const float output_scale_int16 = 1.0f / 32768.0f;
 const int output_zero_point_int8 = -128;
 const int output_zero_point_uint8 = 0;
+const int output_zero_point_int16 = 0;
+
+// Empirical tolerance in quantization space
+const float tolerance_int16 = 7.0;
 
 // 1-dimensional test data.
 const int flat_size_1d = 5;
@@ -291,7 +296,7 @@
                           int input_zero_point, const int* output_dims_data,
                           const float* golden, T* golden_quantized,
                           float output_scale, int output_zero_point,
-                          T* output_data) {
+                          T* output_data, float tolerance = 1.0) {
   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
   const int output_dims_count = ElementCount(*output_dims);
@@ -310,7 +315,7 @@
                      output_zero_point);
 
   ValidateSoftmaxGoldens(tensors, tensors_size, output_data, golden_quantized,
-                         output_dims_count, 1.0);
+                         output_dims_count, tolerance);
 }
 
 }  // namespace
@@ -356,6 +361,21 @@
       tflite::testing::output_zero_point_int8, output_data);
 }
 
+TF_LITE_MICRO_TEST(Softmax1DQuantizedInt16ShouldMatchGolden) {
+  const float input_scale = 0.1f;
+  const int input_zero_point = 0;
+
+  int16_t input_quantized[tflite::testing::flat_size_1d];
+  int16_t golden_quantized[tflite::testing::flat_size_1d];
+  int16_t output_data[tflite::testing::flat_size_1d];
+  tflite::testing::TestSoftmaxQuantized(
+      tflite::testing::shape_1d, tflite::testing::input_data_1d,
+      input_quantized, input_scale, input_zero_point, tflite::testing::shape_1d,
+      tflite::testing::golden_1d, golden_quantized,
+      tflite::testing::output_scale_int16,
+      tflite::testing::output_zero_point_int16, output_data);
+}
+
 TF_LITE_MICRO_TEST(Softmax2DFloatShouldMatchGolden) {
   float output_data[tflite::testing::flat_size_2d];
   tflite::testing::TestSoftmaxFloat(
@@ -393,6 +413,21 @@
       tflite::testing::output_zero_point_int8, output_data);
 }
 
+TF_LITE_MICRO_TEST(Softmax2DQuantizedInt16ShouldMatchGolden) {
+  const float input_scale = 0.1f;
+  const int input_zero_point = 0;
+
+  int16_t input_quantized[tflite::testing::flat_size_2d];
+  int16_t golden_quantized[tflite::testing::flat_size_2d];
+  int16_t output_data[tflite::testing::flat_size_2d];
+  tflite::testing::TestSoftmaxQuantized(
+      tflite::testing::shape_2d, tflite::testing::input_data_2d,
+      input_quantized, input_scale, input_zero_point, tflite::testing::shape_2d,
+      tflite::testing::golden_2d, golden_quantized,
+      tflite::testing::output_scale_int16,
+      tflite::testing::output_zero_point_int16, output_data);
+}
+
 TF_LITE_MICRO_TEST(Softmax3DFloatShouldMatchGolden) {
   float output_data[tflite::testing::flat_size_3d];
   tflite::testing::TestSoftmaxFloat(
@@ -430,6 +465,22 @@
       tflite::testing::output_zero_point_int8, output_data);
 }
 
+TF_LITE_MICRO_TEST(Softmax3DQuantizedInt16ShouldMatchGolden) {
+  const float input_scale = 0.1f;
+  const int input_zero_point = 0;
+
+  int16_t input_quantized[tflite::testing::flat_size_3d];
+  int16_t golden_quantized[tflite::testing::flat_size_3d];
+  int16_t output_data[tflite::testing::flat_size_3d];
+  tflite::testing::TestSoftmaxQuantized(
+      tflite::testing::shape_3d, tflite::testing::input_data_3d,
+      input_quantized, input_scale, input_zero_point, tflite::testing::shape_3d,
+      tflite::testing::golden_3d, golden_quantized,
+      tflite::testing::output_scale_int16,
+      tflite::testing::output_zero_point_int16, output_data,
+      tflite::testing::tolerance_int16);
+}
+
 TF_LITE_MICRO_TEST(Softmax4DFloatShouldMatchGolden) {
   float output_data[tflite::testing::flat_size_4d];
   tflite::testing::TestSoftmaxFloat(
@@ -467,4 +518,19 @@
       tflite::testing::output_zero_point_int8, output_data);
 }
 
+TF_LITE_MICRO_TEST(Softmax4DQuantizedInt16ShouldMatchGolden) {
+  const float input_scale = 0.1f;
+  const int input_zero_point = 0;
+
+  int16_t input_quantized[tflite::testing::flat_size_4d];
+  int16_t golden_quantized[tflite::testing::flat_size_4d];
+  int16_t output_data[tflite::testing::flat_size_4d];
+  tflite::testing::TestSoftmaxQuantized(
+      tflite::testing::shape_4d, tflite::testing::input_data_4d,
+      input_quantized, input_scale, input_zero_point, tflite::testing::shape_4d,
+      tflite::testing::golden_4d, golden_quantized,
+      tflite::testing::output_scale_int16,
+      tflite::testing::output_zero_point_int16, output_data,
+      tflite::testing::tolerance_int16);
+}
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/vexriscv/README.md b/tensorflow/lite/micro/kernels/vexriscv/README.md
new file mode 100644
index 0000000..bba47df
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/vexriscv/README.md
@@ -0,0 +1,39 @@
+# VexRISC-V
+
+## Maintainers
+
+*   [danielyou0230](https://github.com/danielyou0230)
+*   [tcal-x](https://github.com/tcal-x)
+
+## Background
+
+The optimized kernels for
+[VexRISC-V](https://github.com/SpinalHDL/VexRiscv)/[Litex](https://github.com/enjoy-digital/litex)
+are used to run TensorFlow Lite Micro in Zephyr on either
+
+*   Digilent Arty board (e.g. Arty A7)
+*   [Renode](https://github.com/renode/renode): Open source simulation framework
+    (no hardware required)
+
+To run on the Digilent Arty board (FPGA), you'll also need soft-CPU gateware
+for the FPGA; please see
+[Tensorflow lite demo running in Zephyr on Litex/VexRiscv SoC](https://github.com/antmicro/litex-vexriscv-tensorflow-lite-demo)
+by Antmicro for more details.
+
+## Info
+
+To use the VexRISC-V optimized kernels instead of the reference kernels, add
+`TAGS=vexriscv` to the make command. Kernels that do not have an optimization
+for a given microarchitecture fall back to the TFLM reference kernels.
+
+## Example
+
+To compile the binary file with VexRISC-V optimizations, one can use the
+following command:
+
+```
+make -f tensorflow/lite/micro/tools/make/Makefile \
+TAGS=vexriscv \
+TARGET=zephyr_vexriscv \
+person_detection_int8_bin
+```
diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc
index 881b9b9..f279450 100644
--- a/tensorflow/lite/micro/micro_allocator.cc
+++ b/tensorflow/lite/micro/micro_allocator.cc
@@ -337,8 +337,8 @@
     current->bytes = handle->bytes;
     current->first_created = handle->node_idx;
     current->last_used = handle->node_idx;
-    current->needs_allocating = true;
     current->offline_offset = kOnlinePlannedBuffer;
+    current->needs_allocating = true;
   }
   return kTfLiteOk;
 }
@@ -655,6 +655,7 @@
 
   model_is_allocating_ = true;
 
+  TF_LITE_ENSURE_STATUS(InitScratchBufferHandles());
   TF_LITE_ENSURE_STATUS(AllocateTfLiteEvalTensors(model, eval_tensors));
   TF_LITE_ENSURE_STATUS(
       AllocateNodeAndRegistrations(model, node_and_registrations));
@@ -665,7 +666,8 @@
 }
 
 TfLiteStatus MicroAllocator::FinishModelAllocation(
-    const Model* model, TfLiteEvalTensor* eval_tensors) {
+    const Model* model, TfLiteEvalTensor* eval_tensors,
+    void** scratch_buffer_handles) {
   if (!model_is_allocating_) {
     TF_LITE_REPORT_ERROR(error_reporter_,
                          "MicroAllocator: Model allocation finished before "
@@ -676,9 +678,13 @@
   const SubGraph* subgraph = GetSubGraphFromModel(model);
   TFLITE_DCHECK(subgraph != nullptr);
 
+  TF_LITE_ENSURE_STATUS(MoveScratchBufferHandlesToTail());
   TF_LITE_ENSURE_STATUS(CommitStaticMemoryPlan(model, subgraph, eval_tensors));
   TF_LITE_ENSURE_STATUS(AllocateVariables(subgraph, eval_tensors));
 
+  if (scratch_buffer_handles != nullptr) {
+    *scratch_buffer_handles = scratch_buffer_handles_;
+  }
   model_is_allocating_ = false;
   return kTfLiteOk;
 }
@@ -690,49 +696,39 @@
 TfLiteStatus MicroAllocator::RequestScratchBufferInArena(int node_id,
                                                          size_t bytes,
                                                          int* buffer_idx) {
-  // A consistency check to make sure scratch_buffer_handles_ is contiguous i.e.
-  // scratch_buffer_handles_ is pointing to the last allocation from memory
-  // allocator.
-  if (scratch_buffer_handles_ != nullptr &&
-      reinterpret_cast<uint8_t*>(scratch_buffer_handles_) !=
-          memory_allocator_->GetTail()) {
-    TF_LITE_REPORT_ERROR(error_reporter_,
-                         "Internal error: AllocateFromTail can not be called "
-                         "between two RequestScratchBufferInArena calls.");
-    return kTfLiteError;
+  // This method is only called during the Prepare stage, when the scratch
+  // buffer handles are placed in the head.
+
+  // Allocate space for the new scratch buffer handle.
+  TF_LITE_ENSURE_STATUS(memory_allocator_->EnsureHeadSize(
+      sizeof(internal::ScratchBufferHandle) * (scratch_buffer_count_ + 1),
+      alignof(internal::ScratchBufferHandle)));
+
+  if (scratch_buffer_handles_ == nullptr) {
+    // If this is the first scratch buffer handle, place it in the buffer head.
+    scratch_buffer_handles_ = reinterpret_cast<internal::ScratchBufferHandle*>(
+        memory_allocator_->GetBufferHead());
   }
 
+  // Initialize the handle. `data` field will be set during memory planning.
   internal::ScratchBufferHandle* handle =
-      reinterpret_cast<internal::ScratchBufferHandle*>(
-          memory_allocator_->AllocateFromTail(
-              sizeof(internal::ScratchBufferHandle),
-              alignof(internal::ScratchBufferHandle)));
-  if (handle == nullptr) {
-    TF_LITE_REPORT_ERROR(error_reporter_,
-                         "Failed to register scratch buffer handle for node %s",
-                         node_id);
-    return kTfLiteError;
-  }
+      scratch_buffer_handles_ + scratch_buffer_count_;
   *handle = {};
   handle->bytes = bytes;
   handle->node_idx = node_id;
+
+  // Buffer idx starts from 0 in this implementation.
   *buffer_idx = scratch_buffer_count_;
   scratch_buffer_count_ += 1;
-  // scratch_buffer_handles_ is in reverse order. The following code ensures
-  // that scratch_buffers[0] is pointing to the newly allocated handle.
-  scratch_buffer_handles_ = handle;
   return kTfLiteOk;
 }
 
-void* MicroAllocator::GetScratchBuffer(int buffer_idx) const {
-  if (static_cast<size_t>(buffer_idx) >= scratch_buffer_count_) {
-    TF_LITE_REPORT_ERROR(error_reporter_,
-                         "Buffer %d not found. %d buffers available.",
-                         buffer_idx, scratch_buffer_count_);
-    return nullptr;
-  }
-  // scratch_buffer_handles_ is in reverse order.
-  return scratch_buffer_handles_[scratch_buffer_count_ - buffer_idx - 1].data;
+void* MicroAllocator::GetScratchBuffer(void* scratch_buffer_handles,
+                                       int buffer_idx) {
+  internal::ScratchBufferHandle* handle =
+      reinterpret_cast<internal::ScratchBufferHandle*>(scratch_buffer_handles) +
+      buffer_idx;
+  return handle->data;
 }
 
 size_t MicroAllocator::used_bytes() const {
@@ -1035,7 +1031,6 @@
         builder.GetOfflinePlannedOffsets(model, &offline_planner_offsets));
     TF_LITE_ENSURE_STATUS(
         builder.AddTensors(subgraph, offline_planner_offsets, eval_tensors));
-
     TF_LITE_ENSURE_STATUS(builder.AddScratchBuffers(scratch_buffer_handles_));
     const AllocationInfo* allocation_info = builder.Finish();
 
@@ -1051,16 +1046,16 @@
 
     size_t actual_available_arena_size =
         memory_allocator_->GetAvailableMemory(kBufferAlignment);
+
     // Make sure we have enough arena size.
     if (planner.GetMaximumMemorySize() > actual_available_arena_size) {
       TF_LITE_REPORT_ERROR(
           error_reporter_,
-          "Arena size is too small for activation buffers. Needed %d but only "
-          "%d was available.",
+          "Arena size is too small for all buffers. Needed %u but only "
+          "%u was available.",
           planner.GetMaximumMemorySize(), actual_available_arena_size);
       return kTfLiteError;
     }
-
     // Commit the plan.
     TF_LITE_ENSURE_STATUS(CommitPlan(error_reporter_, &planner,
                                      memory_allocator_->GetBufferHead(),
@@ -1073,4 +1068,27 @@
   return kTfLiteOk;
 }
 
+TfLiteStatus MicroAllocator::InitScratchBufferHandles() {
+  scratch_buffer_count_ = 0;
+  scratch_buffer_handles_ = nullptr;
+  return kTfLiteOk;
+}
+
+TfLiteStatus MicroAllocator::MoveScratchBufferHandlesToTail() {
+  if (scratch_buffer_count_ == 0) {
+    return kTfLiteOk;
+  }
+  auto src = scratch_buffer_handles_;
+  internal::ScratchBufferHandle* dest =
+      reinterpret_cast<internal::ScratchBufferHandle*>(
+          memory_allocator_->AllocateFromTail(
+              sizeof(internal::ScratchBufferHandle) * scratch_buffer_count_,
+              alignof(internal::ScratchBufferHandle)));
+  for (size_t i = 0; i < scratch_buffer_count_; i++) {
+    *(dest + i) = *(src + i);
+  }
+  scratch_buffer_handles_ = dest;
+  return kTfLiteOk;
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h
index efd11b8..5b47883 100644
--- a/tensorflow/lite/micro/micro_allocator.h
+++ b/tensorflow/lite/micro/micro_allocator.h
@@ -123,9 +123,12 @@
   // the 'head' section of the memory arena. All variable tensor data will also
   // be allocated. This method should be called after assigning model resources
   // in StartModelAllocation(). The eval_tensors pointer should be the value
-  // passed into this class during StartModelAllocation().
+  // passed into this class during StartModelAllocation(). Scratch buffer
+  // handles are stored in the out-param `scratch_buffer_handles`, which is
+  // later passed to `GetScratchBuffer` to retrieve scratch buffers.
   TfLiteStatus FinishModelAllocation(const Model* model,
-                                     TfLiteEvalTensor* eval_tensors);
+                                     TfLiteEvalTensor* eval_tensors,
+                                     void** scratch_buffer_handles = nullptr);
 
   // Allocates a TfLiteTensor struct and populates the returned value with
   // properties from the model flatbuffer. This struct is allocated from
@@ -160,12 +163,18 @@
   // This method only allocates a BufferHandle holding information for memory
   // planning. The buffer ptr is ready after `FinishModelAllocation` and can
   // be retrieved by `GetScratchBuffer` method using the returned buffer_idx.
-  // Note that there should be no tail allocation between two consecutive
-  // `RequestScratchBufferInArena` calls.
+  // Note that this method should only be called in the Prepare stage.
   TfLiteStatus RequestScratchBufferInArena(int node_id, size_t bytes,
                                            int* buffer_idx);
-  // Returns the pointer to the planned scratch buffer.
-  void* GetScratchBuffer(int buffer_idx) const;
+
+  // Returns the number of scratch buffers in the allocator.
+  size_t GetScratchBufferCount() const { return scratch_buffer_count_; }
+
+  // Returns the pointer to the planned scratch buffer.
+  // `scratch_buffer_handles` should be the corresponding value returned by
+  // `FinishModelAllocation`. It is intentionally typed as void*: the actual
+  // data type is an implementation detail, visible only inside this class.
+  static void* GetScratchBuffer(void* scratch_buffer_handles, int buffer_idx);
 
   // Returns the arena usage in bytes, only available after
   // `FinishModelAllocation`. Otherwise, it will return 0.
@@ -236,13 +245,16 @@
   ErrorReporter* error_reporter_;
   bool model_is_allocating_;
 
-  // In reverse order for efficiency.
-  // i.e. scratch_buffer_handles_[0] is the handle for the last buffer,
-  // corresponding to the last RequestScratchBufferInArena call.
+  // Points to the first allocated scratch buffer handle.
+  // Scratch buffer handles are placed in the head during the `Prepare` stage
+  // and then moved to the tail for the static memory plan.
   internal::ScratchBufferHandle* scratch_buffer_handles_ = nullptr;
   // How many scratch buffers have been allocated.
   size_t scratch_buffer_count_ = 0;
 
+  virtual TfLiteStatus InitScratchBufferHandles();
+  virtual TfLiteStatus MoveScratchBufferHandlesToTail();
+
   TF_LITE_REMOVE_VIRTUAL_DELETE
 };
 
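
For kernel authors the contract is unchanged: request an index in the Prepare stage, look the buffer up in Eval. A minimal sketch of that pattern, assuming a hypothetical OpData that stores the returned index (op registration and the 128-byte size are illustrative):

    #include <cstdint>

    #include "tensorflow/lite/c/common.h"

    struct OpData {
      int scratch_idx = 0;  // Index returned by RequestScratchBufferInArena.
    };

    TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
      OpData* data = reinterpret_cast<OpData*>(node->user_data);
      // Only legal during the Prepare stage; the buffer is not usable yet.
      TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
          context, /*bytes=*/128, &data->scratch_idx));
      return kTfLiteOk;
    }

    TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      OpData* data = reinterpret_cast<OpData*>(node->user_data);
      // Valid once FinishModelAllocation has produced the memory plan.
      uint8_t* scratch = reinterpret_cast<uint8_t*>(
          context->GetScratchBuffer(context, data->scratch_idx));
      // ... use `scratch` as temporary working memory ...
      (void)scratch;
      return kTfLiteOk;
    }
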
diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc
index 8c2f8e0..f4f591b 100644
--- a/tensorflow/lite/micro/micro_interpreter.cc
+++ b/tensorflow/lite/micro/micro_interpreter.cc
@@ -59,13 +59,32 @@
                                                         size_t bytes,
                                                         int* buffer_idx) {
   ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
-  return helper->allocator_->RequestScratchBufferInArena(
-      helper->current_node_idx_, bytes, buffer_idx);
+
+  // We cannot forward the scratch buffer request to the allocator yet;
+  // otherwise the scratch buffer handles would clobber the data in the
+  // `temp` section. These requests are processed once the `temp` section is
+  // deallocated, i.e. after a node has been prepared.
+
+  if (helper->scratch_buffer_count_ >= kMaxScratchBuffersPerOp) {
+    TF_LITE_REPORT_ERROR(
+        helper->error_reporter_,
+        "Node %d is allocating too many scratch buffers per op, max=%d",
+        helper->current_node_idx_, static_cast<int>(kMaxScratchBuffersPerOp));
+    return kTfLiteError;
+  }
+  helper->scratch_buffer_sizes_[helper->scratch_buffer_count_] = bytes;
+  // buffer_idx is 0-indexed.
+  *buffer_idx = helper->scratch_buffer_count_ +
+                helper->allocator_->GetScratchBufferCount();
+  helper->scratch_buffer_count_++;
+  return kTfLiteOk;
 }
 
 void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
-  return reinterpret_cast<ContextHelper*>(ctx->impl_)
-      ->allocator_->GetScratchBuffer(buffer_idx);
+  ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
+
+  return helper->allocator_->GetScratchBuffer(helper->scratch_buffer_handles_,
+                                              buffer_idx);
 }
 
 void ContextHelper::ReportOpError(struct TfLiteContext* context,
@@ -92,12 +111,40 @@
   return &helper->eval_tensors_[tensor_idx];
 }
 
-void ContextHelper::SetNodeIndex(int idx) { current_node_idx_ = idx; }
+void ContextHelper::SetNodeIndex(int idx) {
+  if (scratch_buffer_count_ != 0) {
+    TF_LITE_REPORT_ERROR(error_reporter_,
+                         "Internal error: Please commit scratch buffers "
+                         "before moving to the next node");
+  }
+  current_node_idx_ = idx;
+}
 
 void ContextHelper::SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors) {
   eval_tensors_ = eval_tensors;
 }
 
+void ContextHelper::SetScratchBufferHandles(void* scratch_buffer_handle) {
+  scratch_buffer_handles_ = scratch_buffer_handle;
+}
+
+TfLiteStatus ContextHelper::CommitScratchBuffers() {
+  size_t initial_buffer_count = allocator_->GetScratchBufferCount();
+  for (size_t i = 0; i < scratch_buffer_count_; i++) {
+    int buffer_id;
+    TF_LITE_ENSURE_STATUS(allocator_->RequestScratchBufferInArena(
+        current_node_idx_, scratch_buffer_sizes_[i], &buffer_id));
+    if (static_cast<size_t>(buffer_id) != initial_buffer_count + i) {
+      TF_LITE_REPORT_ERROR(
+          error_reporter_,
+          "Internal error. Scratch buffers are not contiguous.");
+      return kTfLiteError;
+    }
+  }
+  scratch_buffer_count_ = 0;
+  return kTfLiteOk;
+}
+
 }  // namespace internal
 
 MicroInterpreter::MicroInterpreter(const Model* model,
@@ -297,6 +344,7 @@
       }
     }
     allocator_.ResetTempAllocations();
+    TF_LITE_ENSURE_STATUS(context_helper_.CommitScratchBuffers());
   }
   context_helper_.SetNodeIndex(-1);
 
@@ -306,8 +354,12 @@
   context_.RequestScratchBufferInArena = nullptr;
   context_.GetScratchBuffer = context_helper_.GetScratchBuffer;
 
+  void* scratch_buffer_handles = nullptr;
+
   TF_LITE_ENSURE_OK(&context_,
-                    allocator_.FinishModelAllocation(model_, eval_tensors_));
+                    allocator_.FinishModelAllocation(model_, eval_tensors_,
+                                                     &scratch_buffer_handles));
+  context_helper_.SetScratchBufferHandles(scratch_buffer_handles);
   TF_LITE_ENSURE_STATUS(ResetVariableTensors());
 
   tensors_allocated_ = true;
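
The resulting Prepare-stage flow: requests made by a node are only staged in ContextHelper, the node's temp allocations are released, and only then are the staged requests forwarded to the allocator. A standalone toy model of that staging logic (names mirror the real members, but nothing here is the actual interpreter code):

    #include <cstddef>
    #include <cstdio>

    constexpr size_t kMaxScratchBuffersPerOp = 8;

    // Minimal model of the staging logic: requests made during a node's
    // Prepare are only recorded; Commit() forwards them after temp memory
    // has been released.
    class StagingHelper {
     public:
      bool Request(size_t bytes, int* buffer_idx) {
        if (count_ >= kMaxScratchBuffersPerOp) return false;  // Per-op cap.
        sizes_[count_] = bytes;
        *buffer_idx = static_cast<int>(committed_ + count_);  // Global index.
        ++count_;
        return true;
      }
      // Called after ResetTempAllocations(): the arena head is free again,
      // so the recorded requests can be forwarded to the real allocator.
      void Commit() {
        for (size_t i = 0; i < count_; ++i) {
          std::printf("commit %zu bytes as buffer %zu\n", sizes_[i],
                      committed_ + i);
        }
        committed_ += count_;
        count_ = 0;
      }

     private:
      size_t sizes_[kMaxScratchBuffersPerOp] = {};
      size_t count_ = 0;
      size_t committed_ = 0;  // Mirrors allocator_->GetScratchBufferCount().
    };

    int main() {
      StagingHelper helper;
      int idx = 0;
      helper.Request(64, &idx);   // idx == 0
      helper.Request(128, &idx);  // idx == 1
      helper.Commit();            // Safe once the temp section is freed.
      return 0;
    }
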
diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h
index 0983a00..f36d9d8 100644
--- a/tensorflow/lite/micro/micro_interpreter.h
+++ b/tensorflow/lite/micro/micro_interpreter.h
@@ -32,6 +32,8 @@
 
 namespace internal {
 
+constexpr size_t kMaxScratchBuffersPerOp = 8;
+
 // A helper class to encapsulate the implementation of APIs in Context.
 // context->impl_ points to an instance of this class.
 // Check tensorflow/lite/c/common.h for detailed descriptions.
@@ -53,19 +55,28 @@
                                  int tensor_idx);
   static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context,
                                          int tensor_idx);
+  // Commits all scratch buffer allocations to MicroAllocator.
+  TfLiteStatus CommitScratchBuffers();
 
   // Sets the current node index to assist with scratch buffer allocations:
   void SetNodeIndex(int idx);
 
   // Sets the pointer to a list of TfLiteEvalTensor instances.
   void SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors);
+  // Sets the pointer to the scratch buffer handles, which is needed by
+  // `GetScratchBuffer`.
+  void SetScratchBufferHandles(void* scratch_buffer_handle);
 
  private:
-  MicroAllocator* allocator_;
-  ErrorReporter* error_reporter_;
-  const Model* model_;
-  TfLiteEvalTensor* eval_tensors_;
+  MicroAllocator* allocator_ = nullptr;
+  ErrorReporter* error_reporter_ = nullptr;
+  const Model* model_ = nullptr;
+  TfLiteEvalTensor* eval_tensors_ = nullptr;
+  void* scratch_buffer_handles_ = nullptr;
   int current_node_idx_ = -1;
+
+  size_t scratch_buffer_sizes_[kMaxScratchBuffersPerOp];
+  size_t scratch_buffer_count_ = 0;
 };
 
 }  // namespace internal
diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc
index 150dbea..a4a4143 100644
--- a/tensorflow/lite/micro/micro_interpreter_test.cc
+++ b/tensorflow/lite/micro/micro_interpreter_test.cc
@@ -220,38 +220,45 @@
 
   tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver();
 
-  constexpr size_t allocator_buffer_size = 2048;
+  constexpr size_t allocator_buffer_size = 4096;
   uint8_t allocator_buffer[allocator_buffer_size];
-  tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer,
-                                       allocator_buffer_size,
-                                       micro_test::reporter);
-  TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
-  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.inputs_size());
-  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(2), interpreter.outputs_size());
 
-  TfLiteTensor* input = interpreter.input(0);
-  TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
-  TF_LITE_MICRO_EXPECT_EQ(3, input->dims->data[0]);
-  input->data.uint8[0] = 2;
-  input->data.uint8[1] = 3;
-  input->data.uint8[2] = 1;
+  tflite::RecordingMicroAllocator* allocator =
+      tflite::RecordingMicroAllocator::Create(
+          allocator_buffer, allocator_buffer_size, micro_test::reporter);
 
-  uint8_t expected_median = 2;
+  // Make sure kernel memory planning works in a multi-tenant context.
+  for (int i = 0; i < 3; i++) {
+    tflite::MicroInterpreter interpreter(model, op_resolver, allocator,
+                                         micro_test::reporter);
+    TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.inputs_size());
+    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(2), interpreter.outputs_size());
 
-  {
-    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
-    TfLiteTensor* median = interpreter.output(0);
-    TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]);
-    TfLiteTensor* invoke_count = interpreter.output(1);
-    TF_LITE_MICRO_EXPECT_EQ(1, invoke_count->data.i32[0]);
-  }
+    TfLiteTensor* input = interpreter.input(0);
+    TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
+    TF_LITE_MICRO_EXPECT_EQ(3, input->dims->data[0]);
+    input->data.uint8[0] = 2;
+    input->data.uint8[1] = 3;
+    input->data.uint8[2] = 1;
 
-  {
-    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
-    TfLiteTensor* median = interpreter.output(0);
-    TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]);
-    TfLiteTensor* invoke_count = interpreter.output(1);
-    TF_LITE_MICRO_EXPECT_EQ(2, invoke_count->data.i32[0]);
+    uint8_t expected_median = 2;
+
+    {
+      TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
+      TfLiteTensor* median = interpreter.output(0);
+      TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]);
+      TfLiteTensor* invoke_count = interpreter.output(1);
+      TF_LITE_MICRO_EXPECT_EQ(1, invoke_count->data.i32[0]);
+    }
+
+    {
+      TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
+      TfLiteTensor* median = interpreter.output(0);
+      TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]);
+      TfLiteTensor* invoke_count = interpreter.output(1);
+      TF_LITE_MICRO_EXPECT_EQ(2, invoke_count->data.i32[0]);
+    }
   }
 }
 
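
The rewritten test exercises the multi-tenant path: a single allocator, and hence a single arena, backing several interpreters in turn. The usage shape, sketched under the assumption that two valid models and an error reporter are supplied by the caller (the 4096-byte arena size is arbitrary):

    #include <cstddef>
    #include <cstdint>

    #include "tensorflow/lite/c/common.h"
    #include "tensorflow/lite/micro/all_ops_resolver.h"
    #include "tensorflow/lite/micro/micro_allocator.h"
    #include "tensorflow/lite/micro/micro_interpreter.h"

    TfLiteStatus RunTwoTenants(const tflite::Model* model_a,
                               const tflite::Model* model_b,
                               tflite::ErrorReporter* error_reporter) {
      static tflite::AllOpsResolver op_resolver;
      constexpr size_t kArenaSize = 4096;
      static uint8_t arena[kArenaSize];

      tflite::MicroAllocator* allocator =
          tflite::MicroAllocator::Create(arena, kArenaSize, error_reporter);

      // Both interpreters plan inside the same arena; persistent
      // allocations accumulate at the tail across tenants.
      tflite::MicroInterpreter interpreter_a(model_a, op_resolver, allocator,
                                             error_reporter);
      TF_LITE_ENSURE_STATUS(interpreter_a.AllocateTensors());

      tflite::MicroInterpreter interpreter_b(model_b, op_resolver, allocator,
                                             error_reporter);
      TF_LITE_ENSURE_STATUS(interpreter_b.AllocateTensors());
      return kTfLiteOk;
    }
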
diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc
index 23c7ca9..6a2d981 100644
--- a/tensorflow/lite/micro/test_helpers.cc
+++ b/tensorflow/lite/micro/test_helpers.cc
@@ -593,13 +593,18 @@
   TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
       context, sizeof(uint8_t) * NumElements(input->dims),
       &data->sorting_buffer));
+  // Scratch buffer and persistent buffer allocations can be interleaved.
+  data->invoke_count = reinterpret_cast<int*>(
+      context->AllocatePersistentBuffer(context, sizeof(int)));
+  *data->invoke_count = 0;
+
   return kTfLiteOk;
 }
 
 TfLiteStatus SimpleStatefulOp::Invoke(TfLiteContext* context,
                                       TfLiteNode* node) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
-  data->invoke_count += 1;
+  *data->invoke_count += 1;
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
   const uint8_t* input_data = GetTensorData<uint8_t>(input);
@@ -626,7 +631,7 @@
   int32_t* invoke_count_data = GetTensorData<int32_t>(invoke_count);
 
   median_data[0] = sorting_buffer[size / 2];
-  invoke_count_data[0] = data->invoke_count;
+  invoke_count_data[0] = *data->invoke_count;
   return kTfLiteOk;
 }
 
diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h
index a789714..e41beb9 100644
--- a/tensorflow/lite/micro/test_helpers.h
+++ b/tensorflow/lite/micro/test_helpers.h
@@ -49,7 +49,7 @@
   static constexpr int kMedianTensor = 0;
   static constexpr int kInvokeCount = 1;
   struct OpData {
-    int invoke_count = 0;
+    int* invoke_count = nullptr;
     int sorting_buffer = kBufferNotAllocated;
   };
 
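
The pattern behind the OpData change, in isolation: per-op state that must survive across invocations is now kept in a persistent buffer rather than held by value. A sketch, assuming the Prepare-stage AllocatePersistentBuffer hook used above:

    #include "tensorflow/lite/c/common.h"

    struct OpData {
      int* invoke_count = nullptr;  // Lives in the persistent arena section.
    };

    TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
      OpData* data = reinterpret_cast<OpData*>(node->user_data);
      // Persistent allocations outlive temp/scratch planning, so the counter
      // keeps its value across Invoke() calls.
      data->invoke_count = reinterpret_cast<int*>(
          context->AllocatePersistentBuffer(context, sizeof(int)));
      if (data->invoke_count == nullptr) return kTfLiteError;
      *data->invoke_count = 0;
      return kTfLiteOk;
    }

    TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      OpData* data = reinterpret_cast<OpData*>(node->user_data);
      *data->invoke_count += 1;  // Survives: it lives in the arena tail.
      return kTfLiteOk;
    }
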
diff --git a/tensorflow/lite/micro/testing/test_linux_binary.sh b/tensorflow/lite/micro/testing/test_linux_binary.sh
index 1e967be..30cf041 100755
--- a/tensorflow/lite/micro/testing/test_linux_binary.sh
+++ b/tensorflow/lite/micro/testing/test_linux_binary.sh
@@ -1,4 +1,4 @@
-#!/bin/bash -e
+#!/bin/bash
 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 377301d..b896820 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -106,7 +106,11 @@
 # These two must be defined before we include the target specific Makefile.inc
 # because we filter out the examples that are not supported for those targets.
 # See targets/xtensa_xpg_makefile.inc for an example.
-MICRO_LITE_EXAMPLE_TESTS := $(shell find tensorflow/lite/micro/examples/ -name Makefile.inc)
+# We limit the maximum directory depth of the search so that we do not
+# include target-specific Makefiles that are included directly by the
+# main example Makefile.
+# See examples/micro_speech/Makefile.inc for an example.
+MICRO_LITE_EXAMPLE_TESTS := $(shell find tensorflow/lite/micro/examples/ -maxdepth 2 -name Makefile.inc)
 MICRO_LITE_BENCHMARKS := $(wildcard tensorflow/lite/micro/benchmarks/Makefile.inc)
 
 MICROLITE_TEST_SRCS := \
@@ -176,6 +180,7 @@
 tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h \
+tensorflow/lite/kernels/internal/reference/integer_ops/mean.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h \
@@ -211,8 +216,11 @@
 tensorflow/lite/schema/schema_generated.h \
 tensorflow/lite/version.h
 
+# TODO(b/165940489): Figure out how to avoid including fixed point
+# platform-specific headers.
 THIRD_PARTY_CC_HDRS := \
 third_party/gemmlowp/fixedpoint/fixedpoint.h \
+third_party/gemmlowp/fixedpoint/fixedpoint_neon.h \
 third_party/gemmlowp/fixedpoint/fixedpoint_sse.h \
 third_party/gemmlowp/internal/detect_platform.h \
 third_party/gemmlowp/LICENSE \
@@ -256,6 +264,8 @@
 $(eval $(call add_third_party_download,$(GEMMLOWP_URL),$(GEMMLOWP_MD5),gemmlowp,))
 $(eval $(call add_third_party_download,$(FLATBUFFERS_URL),$(FLATBUFFERS_MD5),flatbuffers,))
 $(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,))
+$(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,))
+$(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,))
 
 # These target-specific makefiles should modify or replace options like
 # CXXFLAGS or LIBS to work for a specific targeted architecture. All logic
diff --git a/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc b/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc
index e46ca07..3bbe6f9 100644
--- a/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc
@@ -1,10 +1,19 @@
 # Settings for Hexagon toolchain.
 # REQUIRED:
-#   - Hexagon SDK 3.5 Toolkit (for hexagon-clang++, hexagon-sim).
-#   - HEXAGON_SDK_PREFIX environment variable must be set to location of
+#   - Hexagon SDK 3.5 Toolkit (for qurt, posix libs).
+#     HEXAGON_SDK_ROOT environment variable must be set to the location of
 #     Hexagon_SDK/<version>/ on your machine.
+#   - Hexagon Tools root (for hexagon-clang++, hexagon-sim).
+#     The tool folder may be part of the Hexagon SDK
+#     (e.g. $(HEXAGON_SDK_ROOT)/tools/HEXAGON_Tools) or installed
+#     separately.
+#     HEXAGON_ROOT environment variable must be set to the location of
+#     HEXAGON_Tools on your machine.
+#   - HEXAGON_TOOL_VER: The Hexagon tool version (installed under
+#     HEXAGON_ROOT). For example: 8.3.07
 #   - HEXAGON_CPU_VER: The CPU version to use, will cause a compiler exception
-#                  without providing a version. Acceptable values: v55-v67
+#      if a version is not provided. Valid values may vary depending on the
+#      tools version, but are generally in the range v55-v67.
 #
 # Unlike other targets, there is not currently a way to automatically download
 # the Hexagon SDK.  For this reason, users are required to manually download
@@ -12,8 +21,16 @@
 ifeq ($(TARGET), hexagon)
   TARGET_ARCH := hexagon
 
-  ifndef HEXAGON_SDK_PREFIX
-    $(error HEXAGON_SDK_PREFIX is undefined)
+  ifndef HEXAGON_SDK_ROOT
+    $(error HEXAGON_SDK_ROOT is undefined)
+  endif
+
+  ifndef HEXAGON_TOOL_VER
+    $(error HEXAGON_TOOL_VER is undefined)
+  endif
+
+  ifndef HEXAGON_ROOT
+    $(error HEXAGON_ROOT is undefined)
   endif
 
   ifndef HEXAGON_CPU_VER
@@ -55,6 +72,7 @@
     -mcpu=$(HEXAGON_CPU_VER) \
     -m$(HEXAGON_CPU_VER)
 
+  export PATH := $(HEXAGON_ROOT)/$(HEXAGON_TOOL_VER)/Tools/bin:$(PATH)
   TARGET_TOOLCHAIN_PREFIX := hexagon-
   CXX_TOOL := clang++
   CC_TOOL := clang
@@ -63,11 +81,11 @@
   CCFLAGS += $(PLATFORM_ARGS)
   LDFLAGS += \
     -Wl,--gc-sections -lhexagon \
-    $(HEXAGON_SDK_PREFIX)/tools/HEXAGON_Tools/8.3.07/Tools/target/hexagon/lib/v66/libstdc++.a
+    $(HEXAGON_ROOT)/$(HEXAGON_TOOL_VER)/Tools/target/hexagon/lib/v66/libstdc++.a
 
   INCLUDES += \
-    -I$(HEXAGON_SDK_PREFIX)/libs/common/qurt/computev66/include/posix \
-    -I$(HEXAGON_SDK_PREFIX)/libs/common/qurt/computev66/include/qurt
+    -I$(HEXAGON_SDK_ROOT)/libs/common/qurt/computev66/include/posix \
+    -I$(HEXAGON_SDK_ROOT)/libs/common/qurt/computev66/include/qurt
 
   TEST_SCRIPT := tensorflow/lite/micro/testing/test_hexagon_binary.sh
 endif
diff --git a/tensorflow/lite/nnapi/BUILD b/tensorflow/lite/nnapi/BUILD
index 82d775d..8d08c22 100644
--- a/tensorflow/lite/nnapi/BUILD
+++ b/tensorflow/lite/nnapi/BUILD
@@ -1,4 +1,5 @@
 load("//tensorflow/lite:special_rules.bzl", "if_nnapi")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = [
@@ -13,6 +14,7 @@
         "NeuralNetworksShim.h",
         "NeuralNetworksTypes.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     linkopts = if_nnapi(["-ldl"]),
 )
 
@@ -25,6 +27,7 @@
     hdrs = [
         "nnapi_implementation.h",
     ],
+    compatible_with = get_compatible_with_portable(),
     linkopts = if_nnapi(["-ldl"]) + if_nnapi(
         supported = ["-lrt"],
         supported_android = [],
@@ -38,6 +41,7 @@
     name = "nnapi_util",
     srcs = ["nnapi_util.cc"],
     hdrs = ["nnapi_util.h"],
+    compatible_with = get_compatible_with_portable(),
     deps = [
         ":nnapi_implementation",
         "//tensorflow/lite:util",
diff --git a/tensorflow/lite/profiling/BUILD b/tensorflow/lite/profiling/BUILD
index ac95759..b54e742 100644
--- a/tensorflow/lite/profiling/BUILD
+++ b/tensorflow/lite/profiling/BUILD
@@ -1,13 +1,14 @@
+load("//tensorflow:tensorflow.bzl", "if_not_windows")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 load("//tensorflow/lite:build_def.bzl", "tflite_copts")
+load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite_combined")
 
 package(
     default_visibility = ["//visibility:public"],
     licenses = ["notice"],  # Apache 2.0
 )
 
-common_copts = [
-    "-Wall",
-] + tflite_copts()
+common_copts = tflite_copts() + if_not_windows(["-Wall"])
 
 cc_library(
     name = "profiler",
@@ -23,6 +24,16 @@
     ],
 )
 
+cc_test(
+    name = "profiler_test",
+    srcs = ["profiler_test.cc"],
+    deps = [
+        ":profiler",
+        ":test_main",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "atrace_profiler",
     srcs = ["atrace_profiler.cc"],
@@ -35,10 +46,21 @@
     ],
 )
 
+cc_test(
+    name = "atrace_profiler_test",
+    srcs = ["atrace_profiler_test.cc"],
+    deps = [
+        ":atrace_profiler",
+        ":test_main",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "platform_profiler",
     srcs = ["platform_profiler.cc"],
     hdrs = ["platform_profiler.h"],
+    compatible_with = get_compatible_with_portable(),
     copts = common_copts,
     deps = [
         "//tensorflow/lite/core/api",
@@ -48,16 +70,6 @@
     }),
 )
 
-cc_test(
-    name = "profiler_test",
-    srcs = ["profiler_test.cc"],
-    deps = [
-        ":profiler",
-        "//tensorflow/lite/testing:util",
-        "@com_google_googletest//:gtest",
-    ],
-)
-
 cc_library(
     name = "profile_buffer",
     hdrs = ["profile_buffer.h"],
@@ -69,6 +81,16 @@
     ],
 )
 
+cc_test(
+    name = "profile_buffer_test",
+    srcs = ["profile_buffer_test.cc"],
+    deps = [
+        ":profile_buffer",
+        ":test_main",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "time",
     srcs = ["time.cc"],
@@ -76,6 +98,16 @@
     copts = common_copts,
 )
 
+cc_test(
+    name = "time_test",
+    srcs = ["time_test.cc"],
+    deps = [
+        ":test_main",
+        ":time",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_library(
     name = "memory_info",
     srcs = ["memory_info.cc"],
@@ -84,30 +116,20 @@
 )
 
 cc_test(
-    name = "time_test",
-    srcs = ["time_test.cc"],
-    copts = common_copts,
-    deps = [
-        ":time",
-        "//tensorflow/lite/testing:util",
-        "@com_google_googletest//:gtest",
-    ],
-)
-
-cc_test(
     name = "memory_info_test",
     srcs = ["memory_info_test.cc"],
-    copts = common_copts,
     tags = [
         # Some low-level checks, like heap size check, may break in asan, msan
         # and tsan. So, disable such tests.
         "noasan",
         "nomsan",
         "notsan",
+        # TODO(b/166227284): Fix the test for Android.
+        "tflite_not_portable_android",
     ],
     deps = [
         ":memory_info",
-        "//tensorflow/lite/testing:util",
+        ":test_main",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -125,10 +147,9 @@
 cc_test(
     name = "profile_summary_formatter_test",
     srcs = ["profile_summary_formatter_test.cc"],
-    copts = common_copts,
     deps = [
         ":profile_summary_formatter",
-        "//tensorflow/lite/testing:util",
+        ":test_main",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -151,26 +172,28 @@
 cc_test(
     name = "profile_summarizer_test",
     srcs = ["profile_summarizer_test.cc"],
-    copts = common_copts,
     deps = [
         ":profile_summarizer",
         ":profiler",
+        ":test_main",
         "//tensorflow/lite:framework",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:subgraph_test_util",
         "//tensorflow/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_library(
+    name = "test_main",
+    testonly = 1,
+    srcs = ["test_main.cc"],
+    visibility = ["//visibility:private"],
+    deps = [
         "//tensorflow/lite/testing:util",
         "@com_google_googletest//:gtest",
     ],
 )
 
-cc_test(
-    name = "profile_buffer_test",
-    srcs = ["profile_buffer_test.cc"],
-    deps = [
-        ":profile_buffer",
-        "//tensorflow/lite/testing:util",
-        "@com_google_googletest//:gtest",
-    ],
-)
+tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})
diff --git a/tensorflow/lite/profiling/atrace_profiler.cc b/tensorflow/lite/profiling/atrace_profiler.cc
index 4bdaf9d..cc29c2d 100644
--- a/tensorflow/lite/profiling/atrace_profiler.cc
+++ b/tensorflow/lite/profiling/atrace_profiler.cc
@@ -15,6 +15,9 @@
 #include "tensorflow/lite/profiling/atrace_profiler.h"
 
 #include <dlfcn.h>
+#if defined(__ANDROID__)
+#include <sys/system_properties.h>
+#endif
 
 #include <type_traits>
 
@@ -89,8 +92,16 @@
   FpEndSection atrace_end_section_;
 };
 
-std::unique_ptr<tflite::Profiler> CreateATraceProfiler() {
-  return std::unique_ptr<tflite::Profiler>(new ATraceProfiler());
+std::unique_ptr<tflite::Profiler> MaybeCreateATraceProfiler() {
+#if defined(__ANDROID__)
+  constexpr char kTraceProp[] = "debug.tflite.trace";
+  char trace_enabled[PROP_VALUE_MAX] = "";
+  int length = __system_property_get(kTraceProp, trace_enabled);
+  if (length == 1 && trace_enabled[0] == '1') {
+    return std::unique_ptr<tflite::Profiler>(new ATraceProfiler());
+  }
+#endif  // __ANDROID__
+  return nullptr;
 }
 
 }  // namespace profiling
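
Callers now receive a profiler only when the platform has opted in (on Android, the debug.tflite.trace system property is set to "1"). A hedged usage sketch, assuming the standard Interpreter::SetProfiler hook and caller-owned profiler storage:

    #include <memory>

    #include "tensorflow/lite/interpreter.h"
    #include "tensorflow/lite/profiling/platform_profiler.h"

    // Hypothetical helper: attach a platform profiler only if the platform
    // opted in. `storage` keeps the profiler alive for the interpreter's
    // lifetime.
    void MaybeAttachProfiler(tflite::Interpreter* interpreter,
                             std::unique_ptr<tflite::Profiler>* storage) {
      *storage = tflite::profiling::MaybeCreatePlatformProfiler();
      if (*storage != nullptr) {
        interpreter->SetProfiler(storage->get());
      }
    }
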
diff --git a/tensorflow/lite/profiling/atrace_profiler.h b/tensorflow/lite/profiling/atrace_profiler.h
index d103cbc..044db1c 100644
--- a/tensorflow/lite/profiling/atrace_profiler.h
+++ b/tensorflow/lite/profiling/atrace_profiler.h
@@ -22,7 +22,7 @@
 namespace tflite {
 namespace profiling {
 
-std::unique_ptr<tflite::Profiler> CreateATraceProfiler();
+std::unique_ptr<tflite::Profiler> MaybeCreateATraceProfiler();
 
 }  // namespace profiling
 }  // namespace tflite
diff --git a/tensorflow/lite/profiling/atrace_profiler_test.cc b/tensorflow/lite/profiling/atrace_profiler_test.cc
new file mode 100644
index 0000000..d2a5c52
--- /dev/null
+++ b/tensorflow/lite/profiling/atrace_profiler_test.cc
@@ -0,0 +1,48 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/profiling/atrace_profiler.h"
+
+#if defined(__ANDROID__)
+#include <sys/system_properties.h>
+#endif
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace profiling {
+
+namespace {
+
+TEST(ATraceProfilerTest, MaybeCreateATraceProfiler) {
+  auto default_profiler = MaybeCreateATraceProfiler();
+  EXPECT_EQ(nullptr, default_profiler.get());
+
+#if defined(__ANDROID__)
+  if (__system_property_set("debug.tflite.trace", "1") == 0) {
+    auto profiler = MaybeCreateATraceProfiler();
+    EXPECT_NE(nullptr, profiler.get());
+  }
+
+  if (__system_property_set("debug.tflite.trace", "0") == 0) {
+    auto no_profiler = MaybeCreateATraceProfiler();
+    EXPECT_EQ(nullptr, no_profiler.get());
+  }
+#endif  // __ANDROID__
+}
+
+}  // namespace
+}  // namespace profiling
+}  // namespace tflite
diff --git a/tensorflow/lite/profiling/memory_info_test.cc b/tensorflow/lite/profiling/memory_info_test.cc
index a6bd2e4..9b580b7 100644
--- a/tensorflow/lite/profiling/memory_info_test.cc
+++ b/tensorflow/lite/profiling/memory_info_test.cc
@@ -15,7 +15,6 @@
 #include "tensorflow/lite/profiling/memory_info.h"
 
 #include <gtest/gtest.h>
-#include "tensorflow/lite/testing/util.h"
 
 namespace tflite {
 namespace profiling {
@@ -71,9 +70,3 @@
 }  // namespace memory
 }  // namespace profiling
 }  // namespace tflite
-
-int main(int argc, char** argv) {
-  ::tflite::LogToStderr();
-  ::testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/lite/profiling/platform_profiler.cc b/tensorflow/lite/profiling/platform_profiler.cc
index cd0770c..6ee290c 100644
--- a/tensorflow/lite/profiling/platform_profiler.cc
+++ b/tensorflow/lite/profiling/platform_profiler.cc
@@ -25,11 +25,11 @@
 namespace tflite {
 namespace profiling {
 
-std::unique_ptr<tflite::Profiler> CreatePlatformProfiler() {
+std::unique_ptr<tflite::Profiler> MaybeCreatePlatformProfiler() {
 #if defined(__ANDROID__)
-  return CreateATraceProfiler();
+  return MaybeCreateATraceProfiler();
 #else
-  return std::unique_ptr<tflite::Profiler>(nullptr);
+  return nullptr;
 #endif
 }
 
diff --git a/tensorflow/lite/profiling/platform_profiler.h b/tensorflow/lite/profiling/platform_profiler.h
index 87361b3..52a51f8 100644
--- a/tensorflow/lite/profiling/platform_profiler.h
+++ b/tensorflow/lite/profiling/platform_profiler.h
@@ -22,7 +22,7 @@
 namespace tflite {
 namespace profiling {
 
-std::unique_ptr<tflite::Profiler> CreatePlatformProfiler();
+std::unique_ptr<tflite::Profiler> MaybeCreatePlatformProfiler();
 
 }  // namespace profiling
 }  // namespace tflite
diff --git a/tensorflow/lite/profiling/profile_buffer_test.cc b/tensorflow/lite/profiling/profile_buffer_test.cc
index ab98cbb..457b6ff 100644
--- a/tensorflow/lite/profiling/profile_buffer_test.cc
+++ b/tensorflow/lite/profiling/profile_buffer_test.cc
@@ -20,7 +20,6 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
-#include "tensorflow/lite/testing/util.h"
 
 namespace tflite {
 namespace profiling {
@@ -121,9 +120,3 @@
 }  // namespace
 }  // namespace profiling
 }  // namespace tflite
-
-int main(int argc, char** argv) {
-  ::tflite::LogToStderr();
-  ::testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/lite/profiling/profile_summarizer_test.cc b/tensorflow/lite/profiling/profile_summarizer_test.cc
index 98d2619..fd81c00 100644
--- a/tensorflow/lite/profiling/profile_summarizer_test.cc
+++ b/tensorflow/lite/profiling/profile_summarizer_test.cc
@@ -26,7 +26,6 @@
 #include "tensorflow/lite/kernels/test_util.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/profiling/buffered_profiler.h"
-#include "tensorflow/lite/testing/util.h"
 #include "tensorflow/lite/version.h"
 
 namespace tflite {
@@ -224,9 +223,3 @@
 }  // namespace
 }  // namespace profiling
 }  // namespace tflite
-
-int main(int argc, char** argv) {
-  ::tflite::LogToStderr();
-  ::testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/lite/profiling/profile_summary_formatter_test.cc b/tensorflow/lite/profiling/profile_summary_formatter_test.cc
index 78d46aa..0de0e73 100644
--- a/tensorflow/lite/profiling/profile_summary_formatter_test.cc
+++ b/tensorflow/lite/profiling/profile_summary_formatter_test.cc
@@ -19,7 +19,6 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
-#include "tensorflow/lite/testing/util.h"
 
 namespace tflite {
 namespace profiling {
@@ -156,9 +155,3 @@
 }  // namespace
 }  // namespace profiling
 }  // namespace tflite
-
-int main(int argc, char** argv) {
-  ::tflite::LogToStderr();
-  ::testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/lite/profiling/profiler_test.cc b/tensorflow/lite/profiling/profiler_test.cc
index 1d8455e..c59dca9 100644
--- a/tensorflow/lite/profiling/profiler_test.cc
+++ b/tensorflow/lite/profiling/profiler_test.cc
@@ -22,7 +22,6 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
-#include "tensorflow/lite/testing/util.h"
 
 namespace tflite {
 namespace profiling {
@@ -136,9 +135,3 @@
 }  // namespace
 }  // namespace profiling
 }  // namespace tflite
-
-int main(int argc, char** argv) {
-  ::tflite::LogToStderr();
-  ::testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc b/tensorflow/lite/profiling/test_main.cc
similarity index 67%
copy from tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc
copy to tensorflow/lite/profiling/test_main.cc
index 44ce384..df6b8cb 100644
--- a/tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc
+++ b/tensorflow/lite/profiling/test_main.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,8 +12,12 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/testing/util.h"
 
-#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
-
-// Static initialization for TensorFlow.js op registration.
-static mlir::DialectRegistration<mlir::tfjs::TFJSDialect> tfjs_ops;
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/profiling/time_test.cc b/tensorflow/lite/profiling/time_test.cc
index 6f08479..8a85de9 100644
--- a/tensorflow/lite/profiling/time_test.cc
+++ b/tensorflow/lite/profiling/time_test.cc
@@ -15,7 +15,6 @@
 
 #include "tensorflow/lite/profiling/time.h"
 #include <gtest/gtest.h>
-#include "tensorflow/lite/testing/util.h"
 
 namespace tflite {
 namespace profiling {
@@ -48,9 +47,3 @@
 }  // namespace time
 }  // namespace profiling
 }  // namespace tflite
-
-int main(int argc, char** argv) {
-  ::tflite::LogToStderr();
-  ::testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py
index 8a9f892..7562337 100644
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -428,10 +428,9 @@
   if conversion_summary_dir:
     toco.conversion_summary_dir = conversion_summary_dir
   if target_ops:
-    if set(target_ops) == set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]):
+    if OpsSet.SELECT_TF_OPS in set(target_ops):
       toco.enable_select_tf_ops = True
-    elif set(target_ops) == set([OpsSet.SELECT_TF_OPS]):
-      toco.enable_select_tf_ops = True
+    if set(target_ops) == set([OpsSet.SELECT_TF_OPS]):
       toco.force_select_tf_ops = True
 
   model = _model_flags_pb2.ModelFlags()
diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py
index d0ee2db..c7f86c6 100644
--- a/tensorflow/lite/python/interpreter.py
+++ b/tensorflow/lite/python/interpreter.py
@@ -542,25 +542,26 @@
     return self._interpreter.ResetVariableTensors()
 
   # Experimental and subject to change.
-  def _native_interpreter(self):
-    """Returns the underlying InterpreterWrapper object.
+  def _native_handle(self):
+    """Returns a pointer to the underlying tflite::Interpreter instance.
 
-    This allows users to extend tflite.Interpreter's functionality in custom cpp
-    function. For example,
-    at cpp level:
-      void SomeNewFeature(InterpreterWrapper* wrapper) {
-        // Get access to tflite::Interpreter
-        auto* interpreter = wrapper->interpreter();
-        // ...
-      }
-    at python level:
-      def some_new_feature(interpreter):
-        _cpp_to_py_wrapper.SomeNewFeature(interpreter._native_interpreter())
+    This allows extending tflite.Interpreter's functionality in a custom C++
+    function. Consider how that may work in a custom pybind wrapper:
+
+      m.def("SomeNewFeature", ([](py::object handle) {
+        auto* interpreter =
+          reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>());
+        ...
+      }))
+
+    and corresponding Python call:
+
+      SomeNewFeature(interpreter._native_handle())
 
     Note: This approach is fragile. Users must guarantee the C++ extension build
     is consistent with the tflite.Interpreter's underlying C++ build.
     """
-    return self._interpreter
+    return self._interpreter.interpreter()
 
 
 class InterpreterWithCustomOps(Interpreter):
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc
index f30912c..61771ff 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc
@@ -181,5 +181,8 @@
           },
           R"pbdoc(
              ask the interpreter to set the number of threads to use.
-          )pbdoc");
+          )pbdoc")
+      .def("interpreter", [](InterpreterWrapper& self) {
+        return reinterpret_cast<intptr_t>(self.interpreter());
+      });
 }
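
The value exposed here is just the interpreter pointer cast to an integer; a custom extension casts it back. A hypothetical pybind11 module built on that contract (as the Python docstring warns, both C++ builds must match):

    #include <cstdint>

    #include "pybind11/pybind11.h"
    #include "tensorflow/lite/interpreter.h"

    namespace py = pybind11;

    PYBIND11_MODULE(my_tflite_ext, m) {  // Hypothetical extension module.
      m.def("tensor_count", [](py::object handle) {
        auto* interpreter =
            reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>());
        return interpreter->tensors_size();
      });
    }

From Python this would be called as my_tflite_ext.tensor_count(interpreter._native_handle()).
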
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 0cd7d25..4a0ae9d 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -548,33 +548,7 @@
 
 
 class TFLiteConverterBaseV2(TFLiteConverterBase):
-  """Converter subclass to share functionality between V2 converters.
-
-  Attributes:
-    allow_custom_ops: Boolean indicating whether to allow custom operations.
-      When False, any unknown operation is an error. When True, custom ops are
-      created for any op that is unknown. The developer needs to provide these
-      to the TensorFlow Lite runtime with a custom resolver. (default False)
-    optimizations: Experimental flag, subject to change. A list of optimizations
-      to apply when converting the model. E.g. `[Optimize.DEFAULT]`
-    representative_dataset: A representative dataset that can be used to
-      generate input and output samples for the model. The converter can use the
-      dataset to evaluate different optimizations. Note that this is an optional
-      attribute but it is necessary if INT8 is the only support builtin ops in
-      target ops.
-    target_spec: Experimental flag, subject to change. Specification of target
-      device.
-    inference_input_type: Data type of the input layer. Note that integer types
-      (tf.int8 and tf.uint8) are currently only supported for post training
-      integer quantization. (default tf.float32, must be in {tf.float32,
-      tf.int8, tf.uint8})
-    inference_output_type: Data type of the output layer. Note that integer
-      types (tf.int8 and tf.uint8) are currently only supported for post
-      training integer quantization. (default tf.float32, must be in
-      {tf.float32, tf.int8, tf.uint8})
-    experimental_new_converter: Experimental flag, subject to change. Enables
-      MLIR-based conversion instead of TOCO conversion. (default True)
-  """
+  """Converter subclass to share functionality between V2 converters."""
 
   def __init__(self):
     """Constructor for TFLiteConverter."""
@@ -1119,78 +1093,7 @@
 
 
 class TFLiteConverterBaseV1(TFLiteConverterBase):
-  """Converter subclass to share functionality between V1 converters.
-
-  Attributes:
-    inference_type: Target data type of real-number arrays in the output file.
-      Must be `{tf.float32, tf.uint8}`. If `optimzations` are provided, this
-      parameter is ignored. (default tf.float32)
-    inference_input_type: Target data type of real-number input arrays. Allows
-      for a different type for input arrays. If an integer type is provided and
-      `optimizations` are not used, `quantized_input_stats` must be provided.
-      If `inference_type` is tf.uint8, signaling conversion to a fully quantized
-      model from a quantization-aware trained input model, then
-      `inference_input_type` defaults to tf.uint8. In all other cases,
-      `inference_input_type` defaults to tf.float32. Must be `{tf.float32,
-      tf.uint8, tf.int8}`
-    inference_output_type: Target data type of real-number output arrays. Allows
-      for a different type for output arrays. If `inference_type` is tf.uint8,
-      signaling conversion to a fully quantized model from a quantization-aware
-      trained output model, then `inference_output_type` defaults to tf.uint8.
-      In all other cases, `inference_output_type` must be tf.float32, an error
-      will be thrown otherwise. Must be `{tf.float32, tf.uint8, tf.int8}`
-    output_format: Output file format. Currently must be `{TFLITE,
-      GRAPHVIZ_DOT}`. (default TFLITE)
-    quantized_input_stats: Dict of strings representing input tensor names
-      mapped to tuple of floats representing the mean and standard deviation
-      of the training data (e.g., {"foo" : (0., 1.)}). Only need if
-        `inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
-        (quantized_input_value - mean_value) / std_dev_value. (default {})
-    default_ranges_stats: Tuple of integers representing (min, max) range values
-      for all arrays without a specified range. Intended for experimenting with
-      quantization via "dummy quantization". (default None)
-    drop_control_dependency: Boolean indicating whether to drop control
-      dependencies silently. This is due to TFLite not supporting control
-      dependencies. (default True)
-    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
-      nodes in unexpected locations. Used when the location of the FakeQuant
-      nodes is preventing graph transformations necessary to convert the graph.
-      Results in a graph that differs from the quantized training graph,
-      potentially causing differing arithmetic behavior. (default False)
-    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
-      inputs and outputs of the concat operator for quantized models. Changes
-      the ranges of concat operator overlap when true. (default False)
-    allow_custom_ops: Boolean indicating whether to allow custom operations.
-      When false any unknown operation is an error. When true, custom ops are
-      created for any op that is unknown. The developer will need to provide
-      these to the TensorFlow Lite runtime with a custom resolver. (default
-      False)
-    post_training_quantize: Deprecated. Please specify `[Optimize.DEFAULT]` for
-      `optimizations` instead. Boolean indicating whether to quantize the
-      weights of the converted float model.  Model size will be reduced and
-      there will be latency improvements (at the cost of accuracy). (default
-      False)
-    dump_graphviz_dir: Full filepath of folder to dump the graphs at various
-      stages of processing GraphViz .dot files. Preferred over
-      --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
-      output file. (default None)
-    dump_graphviz_video: Boolean indicating whether to dump the graph after
-      every graph transformation. (default False)
-    conversion_summary_dir: A string indicating the path to the generated
-      conversion logs.
-    target_ops: Deprecated. Please specify `target_spec.supported_ops` instead.
-      Set of OpsSet options indicating which converter to use. (default
-      set([OpsSet.TFLITE_BUILTINS]))
-    target_spec: Experimental flag, subject to change. Specification of target
-      device.
-    optimizations: Experimental flag, subject to change. A list of optimizations
-      to apply when converting the model. E.g. `[Optimize.DEFAULT]`
-    representative_dataset: A representative dataset that can be used to
-      generate input and output samples for the model. The converter can use the
-      dataset to evaluate different optimizations.
-    experimental_new_converter: Experimental flag, subject to change. Enables
-      MLIR-based conversion instead of TOCO conversion. (default True)
-  """
+  """Converter subclass to share functionality between V1 converters."""
 
   def __init__(self, experimental_debug_info_func):
     """Constructor for TFLiteConverter.
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 714eb24..170db3f 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -85,8 +85,6 @@
 
     # Convert model.
     converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
-    # We don't support integer types as we don't have statistical information
-    # to quantize (only supported for post training integer quantization).
     with self.assertRaises(ValueError) as error:
       converter.inference_input_type = inference_input_output_type
       converter.inference_output_type = inference_input_output_type
@@ -212,8 +210,6 @@
     # Convert quantized model.
     quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-    # We don't support integer types as we don't have statistical information
-    # to quantize (only supported for post training integer quantization).
     with self.assertRaises(ValueError) as error:
       quantized_converter.inference_input_type = inference_input_output_type
       quantized_converter.inference_output_type = inference_input_output_type
@@ -223,11 +219,20 @@
         'must be tf.float32.', str(error.exception))
 
   @parameterized.named_parameters(
-      ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
-      ('_INT8InputOutput', lite.constants.INT8),
-      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
-  def testPostTrainingIntegerAllowFloatQuantization(
-      self, inference_input_output_type):
+      ('_Default', False, False, lite.constants.FLOAT),
+      ('_INT8InputOutput', False, False, lite.constants.INT8),
+      ('_UINT8InputOutput', False, False, lite.constants.QUANTIZED_UINT8),
+      ('_INT16Quantize', False, True, lite.constants.FLOAT),
+      ('_INT16Quantize_INT16InputOutput', False, True, lite.constants.INT16),
+      ('_IntOnly', True, False, lite.constants.FLOAT),
+      ('_IntOnly_INT8InputOutput', True, False, lite.constants.INT8),
+      ('_IntOnly_UINT8InputOutput', True, False,
+       lite.constants.QUANTIZED_UINT8),
+      ('_IntOnly_INT16Quantize', True, True, lite.constants.FLOAT),
+      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True,
+       lite.constants.INT16))
+  def testIntegerQuantization(self, is_int_only, is_int16_quantize,
+                              inference_input_output_type):
     func, calibration_gen = self._getIntegerQuantizeModel()
 
     # Convert float model.
@@ -239,111 +244,8 @@
     quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
     quantized_converter.representative_dataset = calibration_gen
-    quantized_converter.inference_input_type = inference_input_output_type
-    quantized_converter.inference_output_type = inference_input_output_type
-    quantized_tflite_model = quantized_converter.convert()
-    self.assertIsNotNone(quantized_tflite_model)
-
-    interpreter = Interpreter(model_content=quantized_tflite_model)
-    interpreter.allocate_tensors()
-    input_details = interpreter.get_input_details()
-    self.assertLen(input_details, 1)
-    self.assertEqual(inference_input_output_type.as_numpy_dtype,
-                     input_details[0]['dtype'])
-    output_details = interpreter.get_output_details()
-    self.assertLen(output_details, 1)
-    self.assertEqual(inference_input_output_type.as_numpy_dtype,
-                     output_details[0]['dtype'])
-
-    # Ensure that the quantized tflite model is smaller.
-    self.assertLess(len(quantized_tflite_model), len(tflite_model))
-
-  def testPostTrainingIntegerAllowFloatQuantizationINT16InputOutput(self):
-    func, calibration_gen = self._getIntegerQuantizeModel()
-
-    # Convert float model.
-    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
-    tflite_model = converter.convert()
-    self.assertTrue(tflite_model)
-
-    # Post-training quantization 16x8 with float fallback allowed.
-    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
-    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-    quantized_converter.representative_dataset = calibration_gen
-    quantized_converter.target_spec.supported_ops = [
-        lite.OpsSet.\
-        EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
-        lite.OpsSet.TFLITE_BUILTINS
-    ]
-    inference_input_output_type = lite.constants.INT16
-    quantized_converter.inference_input_type = inference_input_output_type
-    quantized_converter.inference_output_type = inference_input_output_type
-    quantized_tflite_model = quantized_converter.convert()
-    self.assertIsNotNone(quantized_tflite_model)
-
-    interpreter = Interpreter(model_content=quantized_tflite_model)
-    interpreter.allocate_tensors()
-    input_details = interpreter.get_input_details()
-    self.assertLen(input_details, 1)
-    self.assertEqual(inference_input_output_type.as_numpy_dtype,
-                     input_details[0]['dtype'])
-    output_details = interpreter.get_output_details()
-    self.assertLen(output_details, 1)
-    self.assertEqual(inference_input_output_type.as_numpy_dtype,
-                     output_details[0]['dtype'])
-
-    # Ensure that the quantized tflite model is smaller.
-    self.assertLess(len(quantized_tflite_model), len(tflite_model))
-
-  def testPostTrainingIntegerQuant16x8MismatchInferenceParams(self):
-    # In this test we check that when we do 16x8 post-training
-    # quantization and set inference_input(output)_type to
-    # constants.INT8, we have an error.
-    func, calibration_gen = self._getIntegerQuantizeModel()
-
-    # Convert quantized model.
-    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
-    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-    quantized_converter.representative_dataset = calibration_gen
-    quantized_converter.target_spec.supported_ops = [
-        lite.OpsSet.\
-          EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
-        ]
-
-    with self.assertRaises(ValueError) as error:
-      quantized_converter.inference_input_type = lite.constants.INT8
-      quantized_converter.inference_output_type = lite.constants.INT8
-      quantized_converter.convert()
-    self.assertEqual(
-        "The inference_input_type and inference_output_type "
-        "must be in ['tf.float32', 'tf.int16'].", str(error.exception))
-
-  @parameterized.named_parameters(
-      ('_DefaultFLOAT32InputOutput_UseTargetTypesFlag', lite.constants.FLOAT,
-       False, False),
-      ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT, True, False),
-      ('_INT8InputOutput', lite.constants.INT8, True, False),
-      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8, True, False),
-      ('_INT16InputOutput', lite.constants.INT16, True, True))
-  @test_util.run_v2_only
-  def testPostTrainingIntegerNoFloatQuantization(self,
-                                                 inference_input_output_type,
-                                                 use_target_ops_flag,
-                                                 quantization_16x8):
-    func, calibration_gen = self._getIntegerQuantizeModel()
-
-    # Convert float model.
-    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
-    tflite_model = converter.convert()
-    self.assertTrue(tflite_model)
-
-    # Convert model by specifying target spec (instead of optimizations), since
-    # when targeting an integer only backend, quantization is mandatory.
-    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
-    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
-    quantized_converter.representative_dataset = calibration_gen
-    if use_target_ops_flag:
-      if quantization_16x8:
+    if is_int_only:
+      if is_int16_quantize:
         quantized_converter.target_spec.supported_ops = [
             lite.OpsSet.\
             EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
@@ -353,7 +255,12 @@
             lite.OpsSet.TFLITE_BUILTINS_INT8
         ]
     else:
-      quantized_converter.target_spec.supported_types = [lite.constants.INT8]
+      if is_int16_quantize:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.\
+            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
+            lite.OpsSet.TFLITE_BUILTINS
+        ]
     quantized_converter.inference_input_type = inference_input_output_type
     quantized_converter.inference_output_type = inference_input_output_type
     quantized_tflite_model = quantized_converter.convert()
@@ -373,6 +280,30 @@
     # Ensure that the quantized tflite model is smaller.
     self.assertLess(len(quantized_tflite_model), len(tflite_model))
 
+  @parameterized.named_parameters(
+      ('_INT16Quantize_INT8InputOutput', True, lite.constants.INT8))
+  def testInvalidIntegerQuantization(self, is_int16_quantize,
+                                     inference_input_output_type):
+    func, calibration_gen = self._getIntegerQuantizeModel()
+
+    # Convert quantized model.
+    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.representative_dataset = calibration_gen
+    if is_int16_quantize:
+      quantized_converter.target_spec.supported_ops = [
+          lite.OpsSet.\
+          EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
+          lite.OpsSet.TFLITE_BUILTINS
+      ]
+    with self.assertRaises(ValueError) as error:
+      quantized_converter.inference_input_type = lite.constants.INT8
+      quantized_converter.inference_output_type = lite.constants.INT8
+      quantized_converter.convert()
+    self.assertEqual(
+        "The inference_input_type and inference_output_type "
+        "must be in ['tf.float32', 'tf.int16'].", str(error.exception))
+
   def testCalibrateAndQuantizeBuiltinInt16(self):
     func, calibration_gen = self._getIntegerQuantizeModel()
 
@@ -556,6 +487,36 @@
     converter.convert()
     self._assertValidDebugInfo(converter._debug_info)
 
+  @test_util.run_v2_only
+  def testFlexOpWithInt8OpSet(self):
+    model = tf.keras.Sequential()
+    input_shape = (1, 4, 4, 4, 1)
+    model.add(
+        tf.keras.layers.Conv3D(
+            4,
+            kernel_size=(1, 1, 1),
+            activation='relu',
+            input_shape=input_shape[1:]))
+    model.add(tf.keras.layers.Flatten())
+    model.add(tf.keras.layers.Dense(2, activation='relu'))
+
+    @tf.function(
+        input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)])
+    def _call_fn(inputs):
+      return model(inputs, training=False)
+
+    concrete_func = _call_fn.get_concrete_function(
+        tf.TensorSpec(input_shape, dtype=tf.float32))
+
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_ops = [
+        tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
+        tf.lite.OpsSet.SELECT_TF_OPS,
+    ]
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
 
 class FromSavedModelTest(lite_v2_test_util.ModelTest):
 
diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD
index 0bbb2d5..3f61dc7 100644
--- a/tensorflow/lite/schema/BUILD
+++ b/tensorflow/lite/schema/BUILD
@@ -1,6 +1,7 @@
 load("//tensorflow:tensorflow.bzl", "py_test")
 load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
 load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
+load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
 
 package(
     default_visibility = [
@@ -64,6 +65,7 @@
 flatbuffer_cc_library(
     name = "schema_fbs",
     srcs = ["schema.fbs"],
+    compatible_with = get_compatible_with_portable(),
 )
 
 # Generic schema for flatbuffer converter (but with mutable makes bigger).
diff --git a/tensorflow/lite/string_util.h b/tensorflow/lite/string_util.h
index 0c6ce0b..b8f3fcd 100644
--- a/tensorflow/lite/string_util.h
+++ b/tensorflow/lite/string_util.h
@@ -16,8 +16,9 @@
 // Util methods to read and write String tensors.
 // String tensors are considered to be char tensor with protocol.
 //   [0, 3] 4 bytes: N, num of strings in the tensor in little endian.
-//   [(i+1)*4, (i+1)*4+3] 4 bytes: offset of i-th string in little endian.
-//   [(N+2)*4, (N+2)*4+3] 4 bytes: length of the whole char buffer.
+//   [(i+1)*4, (i+1)*4+3] 4 bytes: offset of i-th string in little endian,
+//                                 for i from 0 to N-1.
+//   [(N+1)*4, (N+1)*4+3] 4 bytes: length of the whole char buffer.
 //   [offset(i), offset(i+1) - 1] : content of i-th string.
 // Example of a string tensor:
 // [
diff --git a/tensorflow/lite/testing/op_tests/leaky_relu.py b/tensorflow/lite/testing/op_tests/leaky_relu.py
index e37df77..0d2ec38 100644
--- a/tensorflow/lite/testing/op_tests/leaky_relu.py
+++ b/tensorflow/lite/testing/op_tests/leaky_relu.py
@@ -28,12 +28,13 @@
 def make_leaky_relu_tests(options):
   """Make a set of tests to do LeakyRelu."""
 
-  test_parameters = [
-      {
-          "input_shape": [[], [1], [5], [1, 10, 10, 3], [3, 3, 3, 3]],
-          "alpha": [0.1, 1.0, 2.0, -0.1, -1.0, -2.0],
-      },
-  ]
+  test_parameters = [{
+      "input_shape": [[], [1], [5], [1, 10, 10, 3], [3, 3, 3, 3]],
+      "alpha": [0.1, 1.0, 2.0, -0.1, -1.0, -2.0],
+      "fully_quantize": [False, True],
+      "input_range": [(-3, 10)],
+      "quant_16x8": [False, True],
+  }]
 
   def build_graph(parameters):
     """Build the graph for the test case."""
diff --git a/tensorflow/lite/testing/op_tests/range.py b/tensorflow/lite/testing/op_tests/range.py
index ad3d2df..d78742f 100644
--- a/tensorflow/lite/testing/op_tests/range.py
+++ b/tensorflow/lite/testing/op_tests/range.py
@@ -29,7 +29,7 @@
 
   test_parameters = [{
       "dtype": [tf.int32, tf.float32],
-      "offset": [10, 100, 1000],
+      "offset": [10, 100, 1000, 0],
       "delta": [1, 2, 3, 4, -1, -2, -3, -4],
   }]
 
diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD
index 774e7ed..eb3f37a 100644
--- a/tensorflow/lite/tools/benchmark/BUILD
+++ b/tensorflow/lite/tools/benchmark/BUILD
@@ -158,7 +158,6 @@
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/kernels:cpu_backend_context",
-        "//tensorflow/lite/profiling:platform_profiler",
         "//tensorflow/lite/profiling:profile_summary_formatter",
         "//tensorflow/lite/profiling:profiler",
         "//tensorflow/lite/tools:logging",
diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md
index 453ea5b..df432da 100644
--- a/tensorflow/lite/tools/benchmark/README.md
+++ b/tensorflow/lite/tools/benchmark/README.md
@@ -36,11 +36,6 @@
     mean use no delay.
 *   `enable_op_profiling`: `bool` (default=false) \
     Whether to enable per-operator profiling measurement.
-*   `enable_platform_tracing`: `bool` (default=false) \
-    Whether to enable platform-wide tracing. Needs to be combined with
-    'enable_op_profiling'. Note, the platform-wide tracing might not work if the
-    tool runs as a commandline native binary. For example, on Android, the
-    ATrace-based tracing only works when the tool is launched as an APK.
 *   `profiling_output_csv_file`: `str` (default="") \
     File path to export profile data to as CSV. The results are printed to
     `stdout` if option is not set. Requires `enable_op_profiling` to be `true`
diff --git a/tensorflow/lite/tools/benchmark/android/README.md b/tensorflow/lite/tools/benchmark/android/README.md
index 3475d47..d41090d 100644
--- a/tensorflow/lite/tools/benchmark/android/README.md
+++ b/tensorflow/lite/tools/benchmark/android/README.md
@@ -96,7 +96,13 @@
 (0)-(3) Follow the steps (0)-(3) of [build/install/run](#to-buildinstallrun)
 section.
 
-(4) Set up Quick Settings tile for System Tracing app on your device. Follow the
+(4) Enable platform tracing.
+
+```
+adb shell setprop debug.tflite.trace 1
+```
+
+(5) Set up Quick Settings tile for System Tracing app on your device. Follow the
 [instruction](https://developer.android.com/topic/performance/tracing/on-device#set-up-tile).
 The System Tracing tile will be added to the Quick Settings panel.
 
@@ -105,20 +111,20 @@
 [guide](https://developer.android.com/topic/performance/tracing/on-device#app-menu)
 for more information.
 
-(5) Tap the System Tracing tile, which has the label "Record trace". The tile
+(6) Tap the System Tracing tile, which has the label "Record trace". The tile
 becomes enabled, and a persistent notification appears to notify you that the
 system is now recording a trace.
 
-(6) Run the benchmark with platform tracing enabled.
+(7) Run the benchmark with platform tracing enabled.
 
 ```
 adb shell am start -S \
   -n org.tensorflow.lite.benchmark/.BenchmarkModelActivity \
   --es args '"--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
-  --num_threads=4 --enable_op_profiling=true --enable_platform_tracing=true"'
+  --num_threads=4"'
 ```
 
-(7) Wait until the benchmark finishes. It can be checked from Android log
+(8) Wait until the benchmark finishes. It can be checked from Android log
 messages, e.g.,
 
 ```
@@ -127,14 +133,14 @@
 ... tflite  : Average inference timings in us: Warmup: 91471, Init: 4108, Inference: 80660.1
 ```
 
-(8) Stop tracing by tapping either the System Tracing tile in the Quick Settings
+(9) Stop tracing by tapping either the System Tracing tile in the Quick Settings
 panel or on the System Tracing notification. The system displays a new
 notification that contains the message "Saving trace". When saving is complete,
 the system dismisses the notification and displays a third notification "Trace
 saved", confirming that your trace has been saved and that you're ready to share
 the system trace.
 
-(9)
+(10)
 [Share](https://developer.android.com/topic/performance/tracing/on-device#share-trace)
 a trace file,
 [convert](https://developer.android.com/topic/performance/tracing/on-device#converting_between_trace_formats)
@@ -143,3 +149,9 @@
 an HTML report. Note that, the captured tracing file format is either in
 Perfetto format or in Systrace format depending on the Android version of your
 device. Select the appropriate method to handle the generated file.
+
+(11) Disable platform tracing.
+
+```
+adb shell setprop debug.tflite.trace 0
+```
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index ef9742e..511244c 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -34,7 +34,6 @@
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/op_resolver.h"
-#include "tensorflow/lite/profiling/platform_profiler.h"
 #include "tensorflow/lite/profiling/profile_summary_formatter.h"
 #include "tensorflow/lite/string_util.h"
 #include "tensorflow/lite/tools/benchmark/benchmark_utils.h"
@@ -61,20 +60,6 @@
 constexpr int kOpProfilingEnabledDefault = false;
 #endif
 
-// Dumps platform-wide tracing files via a platform-based profiler that's built
-// upon platform tracing tools, like ATrace on Android etc.
-class PlatformProfilingListener : public BenchmarkListener {
- public:
-  explicit PlatformProfilingListener(Interpreter* interpreter) {
-    TFLITE_TOOLS_CHECK(interpreter);
-    platform_profiler_ = profiling::CreatePlatformProfiler();
-    interpreter->SetProfiler(platform_profiler_.get());
-  }
-
- private:
-  std::unique_ptr<tflite::Profiler> platform_profiler_;
-};
-
 // Dumps ruy profiling events if the ruy profiler is enabled.
 class RuyProfileListener : public BenchmarkListener {
  public:
@@ -269,8 +254,6 @@
                           BenchmarkParam::Create<int32_t>(1024));
   default_params.AddParam("profiling_output_csv_file",
                           BenchmarkParam::Create<std::string>(""));
-  default_params.AddParam("enable_platform_tracing",
-                          BenchmarkParam::Create<bool>(false));
 
   for (const auto& delegate_provider :
        tools::GetRegisteredDelegateProviders()) {
@@ -331,10 +314,7 @@
       CreateFlag<std::string>(
           "profiling_output_csv_file", &params_,
           "File path to export profile data as CSV, if not set "
-          "prints to stdout."),
-      CreateFlag<bool>("enable_platform_tracing", &params_,
-                       "enable platform-wide tracing, only meaningful when "
-                       "--enable_op_profiling is set to true.")};
+          "prints to stdout.")};
 
   flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
 
@@ -369,8 +349,6 @@
                       "Max profiling buffer entries", verbose);
   LOG_BENCHMARK_PARAM(std::string, "profiling_output_csv_file",
                       "CSV File to export profiling data to", verbose);
-  LOG_BENCHMARK_PARAM(bool, "enable_platform_tracing",
-                      "Enable platform-wide tracing", verbose);
 
   for (const auto& delegate_provider :
        tools::GetRegisteredDelegateProviders()) {
@@ -746,11 +724,6 @@
 BenchmarkTfLiteModel::MayCreateProfilingListener() const {
   if (!params_.Get<bool>("enable_op_profiling")) return nullptr;
 
-  if (params_.Get<bool>("enable_platform_tracing")) {
-    return std::unique_ptr<BenchmarkListener>(
-        new PlatformProfilingListener(interpreter_.get()));
-  }
-
   return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
       interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
       params_.Get<std::string>("profiling_output_csv_file"),
diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
index d320a90..31405df 100644
--- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
+++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
@@ -226,6 +226,17 @@
     }                                                                      \
   } while (0)
 
+#define TF_LITE_ENSURE_NEAR(context, a, b, epsilon)                          \
+  do {                                                                       \
+    auto delta = ((a) > (b)) ? ((a) - (b)) : ((b) - (a));                    \
+    if (delta > (epsilon)) {                                                 \
+      TF_LITE_KERNEL_LOG((context), "%s:%d %s not near %s (%f != %f)",       \
+                         __FILE__, __LINE__, #a, #b, static_cast<double>(a), \
+                         static_cast<double>(b));                            \
+      return kTfLiteError;                                                   \
+    }                                                                        \
+  } while (0)
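+
+// A usage sketch (hypothetical values): inside code that returns
+// TfLiteStatus, compare a computed quantity against an expected one and
+// bail out with kTfLiteError when they differ by more than the tolerance:
+//
+//   TF_LITE_ENSURE_NEAR(context, output_scale, 1.0f / 256.0f, 1e-6f);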
+
 #define TF_LITE_ENSURE_OK(context, status) \
   do {                                     \
     const TfLiteStatus s = (status);       \
diff --git a/tensorflow/lite/tools/cmake/README.md b/tensorflow/lite/tools/cmake/README.md
new file mode 100644
index 0000000..7624b66
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/README.md
@@ -0,0 +1,50 @@
+# Build TensorFlow Lite with CMake
+
+This page describes how to build the TensorFlow Lite static library with the
+CMake tool.
+
+The following instructions have been tested on an Ubuntu 16.04.3 64-bit PC
+(AMD64) and with the TensorFlow devel Docker image
+[tensorflow/tensorflow:devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
+
+**Note:** This is experimental and subject to change.
+
+**Note:** The following are not currently supported: Android, iOS, tests, and
+host tools (e.g., benchmark and analysis tools).
+
+#### Step 1. Install CMake tool
+
+It requires CMake 3.16 or higher. On Ubuntu, you can simply run the following
+command.
+
+```sh
+sudo apt-get install cmake
+```
+
+Or you can follow [the official CMake installation guide](https://cmake.org/install/).
+
+#### Step 2. Clone TensorFlow repository
+
+```sh
+git clone https://github.com/tensorflow/tensorflow.git tensorflow_src
+```
+
+**Note:** If you're using the TensorFlow Docker image, the repo is already
+provided in `/tensorflow_src/`.
+
+#### Step 3. Create CMake build directory and run CMake tool
+
+```sh
+mkdir tflite_build
+cd tflite_build
+cmake ../tensorflow_src/tensorflow/lite
+```
+
+#### Step 4. Build TensorFlow Lite
+
+```sh
+cmake --build . -j
+```
+
+**Note:** This should compile a static library `libtensorflow-lite.a` in the
+current directory.
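+
+If you consume the static library from your own CMake project, a minimal
+sketch (assuming the checkout layout above, and that the TFLite CMake project
+defines the `tensorflow-lite` target behind `libtensorflow-lite.a`; `my_app`
+and `main.cc` are placeholders) could look like:
+
+```cmake
+cmake_minimum_required(VERSION 3.16)
+project(my_app CXX)
+
+# Pull in the TensorFlow Lite CMake project from the cloned repository.
+add_subdirectory(tensorflow_src/tensorflow/lite tflite_build EXCLUDE_FROM_ALL)
+
+add_executable(my_app main.cc)
+target_link_libraries(my_app tensorflow-lite)
+```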
diff --git a/tensorflow/lite/tools/cmake/modules/Findeigen.cmake b/tensorflow/lite/tools/cmake/modules/Findeigen.cmake
new file mode 100644
index 0000000..1ffb547
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findeigen.cmake
@@ -0,0 +1,24 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tensorflow-lite uses find_package for this package, so override the system
+# installation and build from source instead.
+include(eigen)
+if(eigen_POPULATED)
+  set(EIGEN_FOUND TRUE)
+  get_target_property(EIGEN_INCLUDE_DIRS eigen INTERFACE_INCLUDE_DIRECTORIES)
+  set(EIGEN_LIBRARIES Eigen3::Eigen)
+endif()
+
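+# A usage sketch (assuming this modules directory is on CMAKE_MODULE_PATH so
+# that find_package() resolves to this module rather than a system Eigen;
+# "my_target" is a placeholder):
+#
+#   list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/modules")
+#   find_package(eigen REQUIRED)
+#   target_link_libraries(my_target ${EIGEN_LIBRARIES})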
diff --git a/tensorflow/lite/tools/cmake/modules/Findfarmhash.cmake b/tensorflow/lite/tools/cmake/modules/Findfarmhash.cmake
new file mode 100644
index 0000000..1b0dc28
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findfarmhash.cmake
@@ -0,0 +1,25 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tensorflow-lite uses find_package for this package, so override the system
+# installation and build from source instead.
+include(farmhash)
+if(farmhash_POPULATED)
+  set(FARMHASH_FOUND TRUE)
+  get_target_property(FARMHASH_INCLUDE_DIRS farmhash INTERFACE_INCLUDE_DIRECTORIES)
+  add_library(farmhash::farmhash ALIAS farmhash)
+  set(FARMHASH_LIBRARIES farmhash::farmhash)
+endif()
+
diff --git a/tensorflow/lite/tools/cmake/modules/Findfft2d.cmake b/tensorflow/lite/tools/cmake/modules/Findfft2d.cmake
new file mode 100644
index 0000000..0d07432
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findfft2d.cmake
@@ -0,0 +1,37 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tensorflow-lite uses find_package for this package, so override the system
+# installation and build from source instead.
+include(fft2d)
+if(fft2d_POPULATED)
+  set(FFT2D_FOUND TRUE CACHE BOOL "Found FFT2D")
+  get_target_property(FFT2D_INCLUDE_DIRS fft2d INCLUDE_DIRECTORIES)
+  set(FFT2D_INCLUDE_DIRS ${FFT2D_INCLUDE_DIRS} CACHE STRING
+    "FFT2D include dirs"
+  )
+  set(FFT2D_LIBRARIES
+    fft2d_alloc
+    fft2d_fft4f2d
+    fft2d_fftsg
+    fft2d_fftsg2d
+    fft2d_fftsg3d
+    fft2d_shrtdct
+    CACHE
+    STRING
+    "FFT2D libraries"
+  )
+endif()
+
diff --git a/tensorflow/lite/tools/cmake/modules/Findflatbuffers.cmake b/tensorflow/lite/tools/cmake/modules/Findflatbuffers.cmake
new file mode 100644
index 0000000..feb447b
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findflatbuffers.cmake
@@ -0,0 +1,27 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tensorflow-lite uses find_package for this package, so override the system
+# installation and build from source instead.
+include(flatbuffers)
+if(flatbuffers_POPULATED)
+  set(FLATBUFFERS_FOUND TRUE)
+  get_target_property(FLATBUFFERS_INCLUDE_DIRS flatbuffers INCLUDE_DIRECTORIES)
+  set(FLATBUFFERS_LIBRARIES flatbuffers)
+  set(FLATBUFFERS_PROJECT_DIR "${flatbuffers_SOURCE_DIR}" CACHE STRING
+    "Flatbuffers project dir"
+  )
+endif()
+
diff --git a/tensorflow/lite/tools/cmake/modules/Findgemmlowp.cmake b/tensorflow/lite/tools/cmake/modules/Findgemmlowp.cmake
new file mode 100644
index 0000000..70331ad
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findgemmlowp.cmake
@@ -0,0 +1,29 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tensorflow-lite uses find_package for this package, so override the system
+# installation and build from source instead.
+include(gemmlowp)
+if(gemmlowp_POPULATED)
+  set(GEMMLOWP_FOUND TRUE)
+  get_target_property(GEMMLOWP_INCLUDE_DIRS gemmlowp INTERFACE_INCLUDE_DIRECTORIES)
+  set(GEMMLOWP_LIBRARIES
+    gemmlowp
+    gemmlowp_fixedpoint
+    gemmlowp_profiler
+    gemmlowp_eight_bit_int_gemm
+  )
+endif()
+
diff --git a/tensorflow/lite/tools/cmake/modules/Findneon2sse.cmake b/tensorflow/lite/tools/cmake/modules/Findneon2sse.cmake
new file mode 100644
index 0000000..8354385
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findneon2sse.cmake
@@ -0,0 +1,23 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tensorflow-lite uses find_package for this package, so override the system
+# installation and build from source instead.
+include(neon2sse)
+if(neon2sse_POPULATED)
+  set(NEON2SSE_FOUND TRUE)
+  get_target_property(NEON2SSE_INCLUDE_DIRS NEON_2_SSE INTERFACE_INCLUDE_DIRECTORIES)
+  set(NEON2SSE_LIBRARIES NEON_2_SSE)
+endif()
diff --git a/tensorflow/lite/tools/cmake/modules/Findruy.cmake b/tensorflow/lite/tools/cmake/modules/Findruy.cmake
new file mode 100644
index 0000000..e1517ee
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/Findruy.cmake
@@ -0,0 +1,16 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(ruy)
diff --git a/tensorflow/lite/tools/cmake/modules/OverridableFetchContent.cmake b/tensorflow/lite/tools/cmake/modules/OverridableFetchContent.cmake
new file mode 100644
index 0000000..9ed9510
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/OverridableFetchContent.cmake
@@ -0,0 +1,583 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(FetchContent)
+
+# Pairs of regex --> replacement strings that map Git repositories to archive
+# URLs. GIT_COMMIT is replaced with the hash of the commit.
+set(OVERRIDABLE_FETCH_CONTENT_GITHUB_MATCH
+  "^https?://github.com/([^/]+)/([^/.]+)(\\.git)?\$"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GITHUB_REPLACE
+  "https://github.com/\\1/\\2/archive/GIT_COMMIT.zip"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GITLAB_MATCH
+  "^https?://gitlab.com/([^/]+)/([^/.]+)(\\.git)?"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GITLAB_REPLACE
+  "https://gitlab.com/\\1/\\2/-/archive/GIT_COMMIT/\\2-GIT_COMMIT.tar.gz"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GOOGLESOURCE_MATCH
+  "^(https?://[^.]+\\.googlesource\\.com/.*)"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GOOGLESOURCE_REPLACE
+  "\\1/+archive/GIT_COMMIT.tar.gz"
+)
+# List of prefixes for regex match and replacement variables that map Git
+# repositories to archive URLs.
+list(APPEND OVERRIDABLE_FETCH_CONTENT_GIT_TRANSFORMS
+  OVERRIDABLE_FETCH_CONTENT_GITHUB
+  OVERRIDABLE_FETCH_CONTENT_GITLAB
+  OVERRIDABLE_FETCH_CONTENT_GOOGLESOURCE
+)
+
+# Pairs of regex --> replacement strings that map Git repositories to raw file
+# URLs.
+set(OVERRIDABLE_FETCH_CONTENT_GITHUB_FILE_MATCH
+  "${OVERRIDABLE_FETCH_CONTENT_GITHUB_MATCH}"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GITHUB_FILE_REPLACE
+  "https://raw.githubusercontent.com/\\1/\\2/GIT_COMMIT/FILE_PATH"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GITLAB_FILE_MATCH
+  "${OVERRIDABLE_FETCH_CONTENT_GITLAB_MATCH}"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GITLAB_FILE_REPLACE
+  "https://gitlab.com/\\1/\\2/-/raw/GIT_COMMIT/FILE_PATH"
+)
+set(OVERRIDABLE_FETCH_CONTENT_GOOGLESOURCE_FILE_MATCH
+  "${OVERRIDABLE_FETCH_CONTENT_GOOGLESOURCE_MATCH}"
+)
+# This isn't the raw file; Gitiles doesn't support raw file download without
+# decoding the file from base64.
+set(OVERRIDABLE_FETCH_CONTENT_GOOGLESOURCE_FILE_REPLACE
+  "\\1/+/GIT_COMMIT/FILE_PATH"
+)
+
+# List of prefixes for regex match and replacement variables that map Git
+# repositories to archive URLs.
+list(APPEND OVERRIDABLE_FETCH_CONTENT_GIT_FILE_TRANSFORMS
+  OVERRIDABLE_FETCH_CONTENT_GITHUB_FILE
+  OVERRIDABLE_FETCH_CONTENT_GITLAB_FILE
+  OVERRIDABLE_FETCH_CONTENT_GOOGLESOURCE_FILE
+)
+
+# Try applying replacements to string.
+#
+# TRANSFORMS: List which contains prefixes for _MATCH / _REPLACE replacements
+# to try. For example, given the list "FOO" this will try to apply a regex
+# replacement with the value of FOO_MATCH and FOO_REPLACE.
+# TO_REPLACE: String to apply replacements to.
+# OUTPUT_VAR: Name of the variable to store the URL if successful. If
+# conversion fails this variable will be empty.
+function(_ApplyReplacements TRANSFORMS TO_REPLACE OUTPUT_VAR)
+  foreach(PREFIX ${TRANSFORMS})
+    message(VERBOSE "Try converting ${GIT_REPOSITORY} with ${${PREFIX}_MATCH}")
+    set(MATCH "${${PREFIX}_MATCH}")
+    set(REPLACE "${${PREFIX}_REPLACE}")
+    if(MATCH AND REPLACE)
+      string(REGEX REPLACE
+        "${MATCH}"
+        "${REPLACE}"
+        REPLACED
+        "${TO_REPLACE}"
+      )
+      if(NOT "${REPLACED}" STREQUAL "${TO_REPLACE}")
+        set(${OUTPUT_VAR} "${REPLACED}" PARENT_SCOPE)
+      endif()
+    endif()
+  endforeach()
+endfunction()
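+
+# For example (a sketch with a hypothetical host): given
+#   set(EXAMPLE_MATCH
+#     "^https?://git\\.example\\.com/([^/]+)/([^/.]+)(\\.git)?\$"
+#   )
+#   set(EXAMPLE_REPLACE "https://git.example.com/\\1/\\2/archive/GIT_COMMIT.zip")
+# calling
+#   _ApplyReplacements("EXAMPLE" "https://git.example.com/org/repo.git" OUT_URL)
+# sets OUT_URL to "https://git.example.com/org/repo/archive/GIT_COMMIT.zip",
+# leaving the GIT_COMMIT placeholder for _GitRepoArchiveUrl() to substitute.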
+
+
+# Try to convert a Git repository to an archive URL.
+#
+# GIT_REPOSITORY: Repository URL to convert.
+# GIT_COMMIT: Commit hash or tag to convert.
+# REPORT_WARNING: Whether to report a warning if conversion fails.
+# OUTPUT_VAR: Name of the variable to store the URL if successful. If
+# conversion fails this variable will be empty.
+function(_GitRepoArchiveUrl GIT_REPOSITORY GIT_COMMIT REPORT_WARNING OUTPUT_VAR)
+  list(REMOVE_DUPLICATES OVERRIDABLE_FETCH_CONTENT_GIT_TRANSFORMS)
+  _ApplyReplacements(
+    "${OVERRIDABLE_FETCH_CONTENT_GIT_TRANSFORMS}"
+    "${GIT_REPOSITORY}"
+    REPLACED
+  )
+  if(REPLACED)
+    string(REPLACE "GIT_COMMIT" "${GIT_COMMIT}" WITH_COMMIT "${REPLACED}")
+    message(VERBOSE "${GIT_REPOSITORY} / ${GIT_COMMIT} --> ${WITH_COMMIT}")
+    set(${OUTPUT_VAR} "${WITH_COMMIT}" PARENT_SCOPE)
+  elseif(REPORT_WARNING)
+    message(WARNING
+      "Unable to map ${GIT_REPOSITORY} / ${GIT_COMMIT} to an archive URL"
+    )
+  endif()
+endfunction()
+
+
+# Try to convert a Git repository, commit and relative path to a link to the
+# file.
+#
+# GIT_REPOSITORY: Repository URL to convert.
+# GIT_COMMIT: Commit hash or tag to convert.
+# FILE_PATH: Path to the file.
+# OUTPUT_VAR: Name of the variable to store the URL if successful. If
+# conversion fails this variable will be empty.
+function(_GitRepoFileUrl GIT_REPOSITORY GIT_COMMIT FILE_PATH OUTPUT_VAR)
+  list(REMOVE_DUPLICATES OVERRIDABLE_FETCH_CONTENT_GIT_FILE_TRANSFORMS)
+  _ApplyReplacements(
+    "${OVERRIDABLE_FETCH_CONTENT_GIT_FILE_TRANSFORMS}"
+    "${GIT_REPOSITORY}"
+    REPLACED
+  )
+  if(REPLACED)
+    string(REPLACE "GIT_COMMIT" "${GIT_COMMIT}" WITH_COMMIT "${REPLACED}")
+    string(REPLACE "FILE_PATH" "${FILE_PATH}" WITH_FILE "${WITH_COMMIT}")
+    message(VERBOSE
+      "${GIT_REPOSITORY} / ${GIT_COMMIT} / ${FILE_PATH} --> ${WITH_FILE}"
+    )
+    set(${OUTPUT_VAR} "${WITH_FILE}" PARENT_SCOPE)
+  else()
+    message(WARNING
+      "Unable to map ${GIT_REPOSITORY} / ${GIT_COMMIT} / ${FILE_PATH} to a URL"
+    )
+  endif()
+endfunction()
+
+
+# Try to determine the license URL from a path within the content and
+# cache LICENSE_FILE and LICENSE_URL properties.
+#
+# CONTENT_NAME: Name of the content that hosts the license.
+# LICENSE_FILE: Relative path in the archive.
+# OUTPUT_VAR: Name of variable to store / retrieve the license URL.
+function(_LicenseFileToUrl CONTENT_NAME LICENSE_FILE OUTPUT_VAR)
+  foreach(PROPERTY GIT_REPOSITORY GIT_COMMIT LICENSE_URL)
+    _OverridableFetchContent_GetProperty(
+      "${CONTENT_NAME}"
+      "${PROPERTY}"
+      "${PROPERTY}"
+    )
+  endforeach()
+  _OverridableFetchContent_SetProperty(
+    "${CONTENT_NAME}"
+    LICENSE_FILE
+    "License for ${CONTENT_NAME}"
+    "${LICENSE_FILE}"
+  )
+  if(NOT LICENSE_URL)
+    if(GIT_REPOSITORY AND GIT_COMMIT)
+      # Try to synthesize the license URL from the repo path.
+      _GitRepoFileUrl(
+        "${GIT_REPOSITORY}"
+        "${GIT_COMMIT}"
+        "${LICENSE_FILE}"
+        LICENSE_URL
+      )
+      if(LICENSE_URL)
+        _OverridableFetchContent_SetProperty(
+          "${CONTENT_NAME}"
+          LICENSE_URL
+          "License URL for ${CONTENT_NAME}"
+          "${LICENSE_URL}"
+        )
+        set(${OUTPUT_VAR} "${LICENSE_URL}" PARENT_SCOPE)
+      endif()
+    endif()
+  endif()
+endfunction()
+
+
+# Replacement for FetchContent_Declare() that allows the user to override the
+# download URL for Git and URL sources and also favor fetching via URL vs.
+# a Git repo using variables external to this method.
+#
+# See FetchContent_Declare() and ExternalProject_Add() for the arguments
+# supported by this method.
+#
+# In addition to FetchContent_Declare() and ExternalProject_Add() arguments,
+# this method supports LICENSE_FILE that enables the caller to specify the
+# relative path of the license in the downloaded archive which disables the
+# search for a license in OverridableFetchContent_Populate().
+# LICENSE_URL can be specified to override the URL of the LICENSE_FILE if
+# a direct link to the URL can't be formed from the download path.
+#
+# It's possible to override GIT_REPOSITORY, GIT_TAG, URL and URL_HASH for
+# a target by setting
+# OVERRIDABLE_FETCH_CONTENT_<contentName>_<variable> where <contentName> is the
+# CONTENT_NAME argument content provided to this method and <variable> is the
+# argument of this method to override. For example, given CONTENT_NAME = foo
+# the GIT_REPOSITORY can be overridden by setting
+# OVERRIDABLE_FETCH_CONTENT_foo_GIT_REPOSITORY to the value to use instead.
+#
+# To convert a GIT_REPOSITORY / GIT_TAG reference to a URL,
+# set OVERRIDABLE_FETCH_CONTENT_GIT_REPOSITORY_AND_TAG_TO_URL_<contentName>
+# to ON for one repository or
+# OVERRIDABLE_FETCH_CONTENT_GIT_REPOSITORY_AND_TAG_TO_URL to ON for all
+# repositories. This will, where possible, convert a GIT_REPOSITORY / GIT_TAG
+# reference to an archive URL to download instead, which is typically much
+# faster than cloning a Git repo.
+#
+# If OVERRIDABLE_FETCH_CONTENT_USE_GIT is ON, when a GIT_REPOSITORY and a
+# download URL are specified this method will clone the GIT_REPOSITORY. When
+# OVERRIDABLE_FETCH_CONTENT_USE_GIT is OFF or not set and both GIT_REPOSITORY
+# and download URL are specified the download URL is used instead.
+#
+# To override the archive URL before it's passed to FetchContent_Declare(),
+# set OVERRIDABLE_FETCH_CONTENT_<contentName>_MATCH to a regular expression
+# to match the archive URL and OVERRIDABLE_FETCH_CONTENT_<contentName>_REPLACE
+# with the string to replace the archive URL.
+#
+# All content names passed to this method are added to the global property
+# OVERRIDABLE_FETCH_CONTENT_LIST.
+function(OverridableFetchContent_Declare CONTENT_NAME)
+  set(OVERRIDABLE_ARGS
+    GIT_REPOSITORY
+    GIT_TAG
+    URL
+    URL_HASH
+    URL_MD5
+  )
+  set(ALL_VALUE_ARGS LICENSE_FILE LICENSE_URL ${OVERRIDABLE_ARGS})
+  cmake_parse_arguments(ARGS
+    ""
+    "${ALL_VALUE_ARGS}"
+    ""
+    ${ARGN}
+  )
+  # Optionally override parsed arguments with values from variables in the form
+  # ${CONTENT_NAME}_${OVERRIDABLE_ARG}.
+  foreach(OVERRIDABLE_ARG ${OVERRIDABLE_ARGS})
+    set(OVERRIDE_VALUE
+      ${OVERRIDABLE_FETCH_CONTENT_${CONTENT_NAME}_${OVERRIDABLE_ARG}}
+    )
+    if(NOT "${OVERRIDE_VALUE}" STREQUAL "")
+      set(ARGS_${OVERRIDABLE_ARG} "${OVERRIDE_VALUE}")
+      message(VERBOSE "Overriding ${OVERRIDABLE_ARG} of content "
+        "${CONTENT_NAME} with '${OVERRIDE_VALUE}'"
+      )
+    endif()
+  endforeach()
+
+  # If specified, save the source URL so it's possible to synthesize a link to
+  # the license when the content is populated.
+  if(ARGS_GIT_REPOSITORY AND ARGS_GIT_TAG)
+    _OverridableFetchContent_SetProperty(
+      "${CONTENT_NAME}"
+      GIT_REPOSITORY
+      "Git repo for ${CONTENT_NAME}"
+      "${ARGS_GIT_REPOSITORY}"
+    )
+    _OverridableFetchContent_SetProperty(
+      "${CONTENT_NAME}"
+      GIT_COMMIT
+      "Git commit for ${CONTENT_NAME}"
+      "${ARGS_GIT_TAG}"
+    )
+  endif()
+
+  # Set the license file and URL properties.
+  if(ARGS_LICENSE_URL)
+    _OverridableFetchContent_SetProperty(
+      "${CONTENT_NAME}"
+      LICENSE_URL
+      "License URL for ${CONTENT_NAME}"
+      "${ARGS_LICENSE_URL}"
+    )
+  endif()
+  if(ARGS_LICENSE_FILE)
+    _LicenseFileToUrl(
+      "${CONTENT_NAME}"
+      "${ARGS_LICENSE_FILE}"
+      ARGS_LICENSE_URL
+    )
+  endif()
+
+  # Try mapping to an archive URL.
+  set(ARCHIVE_URL "")
+  if(ARGS_GIT_REPOSITORY AND ARGS_GIT_TAG)
+    _GitRepoArchiveUrl(
+      "${ARGS_GIT_REPOSITORY}"
+      "${ARGS_GIT_TAG}"
+      OFF
+      ARCHIVE_URL
+    )
+    # If conversion from git repository to archive URL is enabled.
+    if(OVERRIDABLE_FETCH_CONTENT_GIT_REPOSITORY_AND_TAG_TO_URL_${CONTENT_NAME}
+       OR OVERRIDABLE_FETCH_CONTENT_GIT_REPOSITORY_AND_TAG_TO_URL)
+      # Try converting to an archive URL.
+      if(NOT ARGS_URL)
+        _GitRepoArchiveUrl(
+          "${ARGS_GIT_REPOSITORY}"
+          "${ARGS_GIT_TAG}"
+          ON
+          ARGS_URL
+        )
+        set(ARCHIVE_URL "${ARGS_URL}")
+      endif()
+    endif()
+  endif()
+
+  # If a download URL and git repository with tag are specified either use
+  # the git repo or the download URL.
+  if(ARGS_URL AND ARGS_GIT_REPOSITORY)
+    if(OVERRIDABLE_FETCH_CONTENT_USE_GIT)
+      unset(ARGS_URL)
+      unset(ARGS_URL_HASH)
+      unset(ARGS_URL_MD5)
+    else()
+      unset(ARGS_GIT_REPOSITORY)
+      unset(ARGS_GIT_TAG)
+    endif()
+  endif()
+
+  # Optionally map the archive URL to a mirror.
+  if(ARGS_URL)
+    _ApplyReplacements(
+      "OVERRIDABLE_FETCH_CONTENT_${CONTENT_NAME}"
+      "${ARGS_URL}"
+      REPLACED
+    )
+    if(REPLACED)
+      set(ARGS_URL "${REPLACED}")
+    endif()
+  endif()
+
+  # Save the archive URL.
+  if(ARGS_URL)
+    set(ARCHIVE_URL "${ARGS_URL}")
+  endif()
+  if(ARCHIVE_URL)
+    _OverridableFetchContent_SetProperty(
+      "${CONTENT_NAME}"
+      ARCHIVE_URL
+      "Archive URL for ${CONTENT_NAME}"
+      "${ARCHIVE_URL}"
+    )
+  endif()
+
+  # Build the list of arguments to pass to FetchContent_Declare() starting with
+  # the overridable arguments.
+  set(OUTPUT_ARGS "")
+  foreach(OVERRIDABLE_ARG ${OVERRIDABLE_ARGS})
+    set(OVERRIDABLE_ARG_VALUE "${ARGS_${OVERRIDABLE_ARG}}")
+    if(OVERRIDABLE_ARG_VALUE)
+      list(APPEND OUTPUT_ARGS ${OVERRIDABLE_ARG} "${OVERRIDABLE_ARG_VALUE}")
+    endif()
+  endforeach()
+  list(APPEND OUTPUT_ARGS ${ARGS_UNPARSED_ARGUMENTS})
+
+  # Add all defined packages to a global property.
+  get_property(OVERRIDABLE_FETCH_CONTENT_LIST GLOBAL PROPERTY
+    OVERRIDABLE_FETCH_CONTENT_LIST
+  )
+  set(DOCUMENTATION "List of all fetched content")
+  define_property(GLOBAL PROPERTY OVERRIDABLE_FETCH_CONTENT_LIST
+    BRIEF_DOCS "${DOCUMENTATION}"
+    FULL_DOCS "${DOCUMENTATION}"
+  )
+  list(APPEND OVERRIDABLE_FETCH_CONTENT_LIST "${CONTENT_NAME}")
+  set_property(GLOBAL PROPERTY OVERRIDABLE_FETCH_CONTENT_LIST
+    "${OVERRIDABLE_FETCH_CONTENT_LIST}"
+  )
+
+  message(VERBOSE "FetchContent_Declare(${CONTENT_NAME} ${OUTPUT_ARGS}")
+  FetchContent_Declare("${CONTENT_NAME}" ${OUTPUT_ARGS})
+endfunction()
+
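+# Example (a sketch; "foo" and the URLs are hypothetical): a build that
+# mirrors its dependencies can redirect the declaration without touching the
+# call site, and ask for archive downloads instead of git clones:
+#
+#   set(OVERRIDABLE_FETCH_CONTENT_foo_GIT_REPOSITORY
+#     "https://mirror.example.com/foo.git"
+#   )
+#   set(OVERRIDABLE_FETCH_CONTENT_GIT_REPOSITORY_AND_TAG_TO_URL ON)
+#   OverridableFetchContent_Declare(foo
+#     GIT_REPOSITORY https://github.com/example/foo
+#     GIT_TAG 1234567890abcdef1234567890abcdef12345678
+#   )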
+
+# Get a property name for this module.
+# CONTENT_NAME: Name of the content associated with the FetchContent function.
+# PROPERTY_NAME: Name of the property.
+# OUTPUT_VAR: Variable to store the name in.
+function(_OverridableFetchContent_GetPropertyName CONTENT_NAME PROPERTY_NAME
+    OUTPUT_VAR)
+  # The implementation of FetchContent_GetProperties() uses the lower case
+  # content name to prefix property names so follow the same pattern here.
+  string(TOLOWER "${CONTENT_NAME}" CONTENT_NAME_LOWER)
+  set(${OUTPUT_VAR}
+    "_OverridableFetchContent_${CONTENT_NAME_LOWER}_${PROPERTY_NAME}"
+    PARENT_SCOPE
+  )
+endfunction()
+
+
+# Set a global property for this module.
+# CONTENT_NAME: Name of the content associated with the FetchContent function.
+# PROPERTY_NAME: Name of the property to set.
+# DOCUMENTATION: Documentation string for the property.
+# PROPERTY_VALUE: Value to set the property to.
+function(_OverridableFetchContent_SetProperty CONTENT_NAME PROPERTY_NAME
+    DOCUMENTATION PROPERTY_VALUE)
+  _OverridableFetchContent_GetPropertyName(
+    "${CONTENT_NAME}"
+    "${PROPERTY_NAME}"
+    GLOBAL_PROPERTY_NAME
+  )
+  define_property(GLOBAL PROPERTY "${GLOBAL_PROPERTY_NAME}"
+    BRIEF_DOCS "${DOCUMENTATION}"
+    FULL_DOCS "${DOCUMENTATION}"
+  )
+  set_property(GLOBAL PROPERTY "${GLOBAL_PROPERTY_NAME}" "${PROPERTY_VALUE}")
+endfunction()
+
+
+# Get a global property for this module.
+# CONTENT_NAME: Name of the content associated with the FetchContent function.
+# PROPERTY_NAME: Name of the property to get.
+# OUTPUT_VAR: Variable to store the value in.
+function(_OverridableFetchContent_GetProperty CONTENT_NAME PROPERTY_NAME
+    OUTPUT_VAR)
+  _OverridableFetchContent_GetPropertyName(
+    "${CONTENT_NAME}"
+    "${PROPERTY_NAME}"
+    GLOBAL_PROPERTY_NAME
+  )
+  get_property(VALUE GLOBAL PROPERTY "${GLOBAL_PROPERTY_NAME}")
+  if(VALUE)
+    set(${OUTPUT_VAR} "${VALUE}" PARENT_SCOPE)
+  endif()
+endfunction()
+
+
+# Export a list of variables to the parent scope of the caller function.
+macro(_OverridableFetchContent_ExportToParentScope)
+  # Export requested variables to the parent scope.
+  foreach(VARIABLE_NAME ${ARGN})
+    if(${VARIABLE_NAME})
+      message(DEBUG "Export ${VARIABLE_NAME} ${${VARIABLE_NAME}}")
+      set(${VARIABLE_NAME} "${${VARIABLE_NAME}}" PARENT_SCOPE)
+    endif()
+  endforeach()
+endmacro()
+
+
+# Wrapper around FetchContent_GetProperties().
+#
+# Sets the same variables as FetchContent_GetProperties() in addition to:
+# * <contentName>_LICENSE_FILE: License file relative to
+#   <contentName>_SOURCE_DIR if found.
+# * <contentName>_LICENSE_URL: License URL if the file is found.
+# * <contentName_ARCHIVE_URL: URL to the source package.
+function(OverridableFetchContent_GetProperties CONTENT_NAME)
+  set(EXPORT_VARIABLE_ARGS SOURCE_DIR BINARY_DIR POPULATED)
+  cmake_parse_arguments(ARGS
+    ""
+    "${EXPORT_VARIABLE_ARGS}"
+    ""
+    ${ARGN}
+  )
+
+  # The implementation of FetchContent_Populate() uses the lower case
+  # content name to prefix returned variable names.
+  string(TOLOWER "${CONTENT_NAME}" CONTENT_NAME_LOWER)
+  # Get the names of the variables to export to the parent scope.
+  set(EXPORT_VARIABLES "")
+  set(OUTPUT_ARGS "")
+  foreach(ARG_NAME ${EXPORT_VARIABLE_ARGS})
+    set(ARG_VARIABLE_NAME "ARGS_${ARG_NAME}")
+    set(ARG_VARIABLE_VALUE "${${ARG_VARIABLE_NAME}}")
+    list(APPEND EXPORT_VARIABLES "${CONTENT_NAME_LOWER}_${ARG_NAME}")
+    if(ARG_VARIABLE_VALUE)
+      list(APPEND EXPORT_VARIABLES "${ARG_VARIABLE_VALUE}")
+      list(APPEND OUTPUT_ARGS "${ARG_NAME}" "${ARG_VARIABLE_VALUE}")
+    endif()
+  endforeach()
+  list(APPEND OUTPUT_ARGS ${ARGS_UNPARSED_ARGUMENTS})
+
+  foreach(EXPORT_PROPERTY LICENSE_FILE LICENSE_URL ARCHIVE_URL)
+    _OverridableFetchContent_GetProperty("${CONTENT_NAME}"
+      "${EXPORT_PROPERTY}"
+      "${EXPORT_PROPERTY}"
+    )
+    set(PROPERTY_VALUE "${${EXPORT_PROPERTY}}")
+    if(PROPERTY_VALUE)
+      set(${CONTENT_NAME}_${EXPORT_PROPERTY} "${PROPERTY_VALUE}" PARENT_SCOPE)
+    endif()
+  endforeach()
+  FetchContent_GetProperties("${CONTENT_NAME}" ${OUTPUT_ARGS})
+  _OverridableFetchContent_ExportToParentScope(${EXPORT_VARIABLES})
+endfunction()
+
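+# Example (sketch, assuming content "foo" was declared above):
+#
+#   OverridableFetchContent_GetProperties(foo)
+#   if(foo_LICENSE_FILE)
+#     message(STATUS "License: ${foo_LICENSE_FILE} at ${foo_LICENSE_URL}")
+#   endif()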
+
+# Replacement for FetchContent_Populate() that searches a newly cloned
+# repository for a top level license file and provides it to the caller
+# via the <contentName>_LICENSE_FILE and <contentName>_LICENSE_URL variables
+# where <contentName> is the value passed as the CONTENT_NAME argument of this
+# method.
+#
+# To ensure a fetched repo has a license file and URL, set
+# OVERRIDABLE_FETCH_CONTENT_LICENSE_CHECK_<contentName> to ON for one
+# repository or OVERRIDABLE_FETCH_CONTENT_LICENSE_CHECK to ON for all
+# repositories.
+function(OverridableFetchContent_Populate CONTENT_NAME)
+  # The implementation of FetchContent_Populate() uses the lower case
+  # content name to prefix returned variable names.
+  string(TOLOWER "${CONTENT_NAME}" CONTENT_NAME_LOWER)
+
+  FetchContent_Populate("${CONTENT_NAME}")
+  OverridableFetchContent_GetProperties("${CONTENT_NAME}")
+
+  # If a license file isn't cached try finding it in the repo.
+  set(LICENSE_FILE "${${CONTENT_NAME_LOWER}_LICENSE_FILE}")
+  set(LICENSE_URL "${${CONTENT_NAME_LOWER}_LICENSE_URL}")
+  if(${CONTENT_NAME}_POPULATED AND NOT LICENSE_FILE)
+    set(SOURCE_DIR "${${CONTENT_NAME_LOWER}_SOURCE_DIR}")
+    find_file(_${CONTENT_NAME_LOWER}_LICENSE_FILE
+      NAMES LICENSE LICENSE.md LICENSE.txt NOTICE COPYING
+      PATHS "${SOURCE_DIR}"
+      DOC "${CONTENT_NAME} license file"
+      NO_DEFAULT_PATH
+      NO_CMAKE_FIND_ROOT_PATH
+    )
+    set(LICENSE_FILE "${_${CONTENT_NAME_LOWER}_LICENSE_FILE}")
+    if(LICENSE_FILE)
+      file(RELATIVE_PATH LICENSE_FILE "${SOURCE_DIR}" "${LICENSE_FILE}")
+      file(TO_CMAKE_PATH "${LICENSE_FILE}" LICENSE_FILE)
+    endif()
+  endif()
+  # If a LICENSE_FILE was found populate the URL.
+  if(LICENSE_FILE AND NOT LICENSE_URL)
+    _LicenseFileToUrl(
+      "${CONTENT_NAME}"
+      "${LICENSE_FILE}"
+      LICENSE_URL
+    )
+  endif()
+
+  # If enabled, check for source licenses.
+  if(OVERRIDABLE_FETCH_CONTENT_LICENSE_CHECK OR
+     OVERRIDABLE_FETCH_CONTENT_LICENSE_CHECK_${CONTENT_NAME})
+    message(DEBUG "LICENSE_FILE: ${LICENSE_FILE}, LICENSE_URL: ${LICENSE_URL}")
+    if(NOT LICENSE_FILE)
+      message(FATAL_ERROR
+        "Required license file not found for ${CONTENT_NAME}"
+      )
+    endif()
+    if(NOT LICENSE_URL)
+      message(FATAL_ERROR
+        "Required license URL not found for ${CONTENT_NAME}"
+      )
+    endif()
+  endif()
+
+  # Export return values to the parent scope.
+  set(EXPORT_VARIABLES "")
+  foreach(VARIABLE_POSTFIX SOURCE_DIR BINARY_DIR POPULATED)
+    list(APPEND EXPORT_VARIABLES "${CONTENT_NAME_LOWER}_${VARIABLE_POSTFIX}")
+  endforeach()
+  _OverridableFetchContent_ExportToParentScope(${EXPORT_VARIABLES})
+endfunction()
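+
+# Example (a sketch; "foo" is hypothetical): require license discovery while
+# populating, following the Declare/GetProperties/Populate flow used by the
+# dependency modules below:
+#
+#   set(OVERRIDABLE_FETCH_CONTENT_LICENSE_CHECK_foo ON)
+#   OverridableFetchContent_GetProperties(foo)
+#   if(NOT foo_POPULATED)
+#     OverridableFetchContent_Populate(foo)
+#   endif()
+#   add_subdirectory("${foo_SOURCE_DIR}" "${foo_BINARY_DIR}" EXCLUDE_FROM_ALL)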
diff --git a/tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake b/tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake
new file mode 100644
index 0000000..5f362f4
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake
@@ -0,0 +1,44 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Use absl_base as a proxy for the project being included.
+if(TARGET absl_base OR abseil-cpp_POPULATED)
+  return()
+endif()
+
+include(OverridableFetchContent)
+
+OverridableFetchContent_Declare(
+  abseil-cpp
+  GIT_REPOSITORY https://github.com/abseil/abseil-cpp
+  GIT_TAG 20200225.2 # TODO: What version do gRPC and TFLite need?
+  GIT_SHALLOW TRUE
+  GIT_PROGRESS TRUE
+  PREFIX "${CMAKE_BINARY_DIR}"
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/abseil-cpp"
+)
+OverridableFetchContent_GetProperties(abseil-cpp)
+if(NOT abseil-cpp_POPULATED)
+  OverridableFetchContent_Populate(abseil-cpp)
+endif()
+
+set(ABSL_USE_GOOGLETEST_HEAD OFF CACHE BOOL "Disable googletest")
+set(ABSL_RUN_TESTS OFF CACHE BOOL "Disable build of ABSL tests")
+add_subdirectory(
+  "${abseil-cpp_SOURCE_DIR}"
+  "${abseil-cpp_BINARY_DIR}"
+  EXCLUDE_FROM_ALL
+)
+
diff --git a/tensorflow/lite/tools/cmake/modules/absl-config.cmake b/tensorflow/lite/tools/cmake/modules/absl-config.cmake
new file mode 100644
index 0000000..7504174
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/absl-config.cmake
@@ -0,0 +1,187 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# grpc uses find_package in CONFIG mode for this package, so override the
+# system installation and build from source instead.
+include(abseil-cpp)
+if(abseil-cpp_POPULATED)
+  set(_ABSL_LIBRARY_NAMES
+    algorithm
+    algorithm_container
+    any
+    atomic_hook
+    atomic_hook_test_helper
+    awesome
+    bad_any_cast
+    bad_any_cast_impl
+    bad_optional_access
+    bad_variant_access
+    base
+    base_internal
+    bind_front
+    bits
+    btree
+    btree_test_common
+    city
+    civil_time
+    compare
+    compressed_tuple
+    config
+    conformance_testing
+    container
+    container_common
+    container_memory
+    cord
+    cord_test_helpers
+    core_headers
+    counting_allocator
+    debugging
+    debugging_internal
+    demangle_internal
+    dynamic_annotations
+    endian
+    errno_saver
+    examine_stack
+    exception_safety_testing
+    exception_testing
+    exponential_biased
+    failure_signal_handler
+    fantastic_lib
+    fast_type_id
+    fixed_array
+    flags
+    flags_commandlineflag
+    flags_commandlineflag_internal
+    flags_config
+    flags_internal
+    flags_marshalling
+    flags_parse
+    flags_path_util
+    flags_private_handle_accessor
+    flags_program_name
+    flags_reflection
+    flags_usage
+    flags_usage_internal
+    flat_hash_map
+    flat_hash_set
+    function_ref
+    graphcycles_internal
+    hash
+    hash_function_defaults
+    hash_generator_testing
+    hash_policy_testing
+    hash_policy_traits
+    hash_testing
+    hashtable_debug
+    hashtable_debug_hooks
+    hashtablez_sampler
+    have_sse
+    hdrs
+    inlined_vector
+    inlined_vector_internal
+    int128
+    kernel_timeout_internal
+    layout
+    leak_check
+    leak_check_api_disabled_for_testing
+    leak_check_api_enabled_for_testing
+    leak_check_disable
+    log_severity
+    main_lib
+    malloc_internal
+    memory
+    meta
+    node_hash_map
+    node_hash_policy
+    node_hash_set
+    numeric
+    optional
+    per_thread_sem_test_common
+    periodic_sampler
+    pow10_helper
+    pretty_function
+    random_bit_gen_ref
+    random_distributions
+    random_internal_distribution_caller
+    random_internal_distribution_test_util
+    random_internal_explicit_seed_seq
+    random_internal_fast_uniform_bits
+    random_internal_fastmath
+    random_internal_generate_real
+    random_internal_iostream_state_saver
+    random_internal_mock_helpers
+    random_internal_mock_overload_set
+    random_internal_nonsecure_base
+    random_internal_pcg_engine
+    random_internal_platform
+    random_internal_pool_urbg
+    random_internal_randen
+    random_internal_randen_engine
+    random_internal_randen_hwaes
+    random_internal_randen_hwaes_impl
+    random_internal_randen_slow
+    random_internal_salted_seed_seq
+    random_internal_seed_material
+    random_internal_sequence_urbg
+    random_internal_traits
+    random_internal_uniform_helper
+    random_internal_wide_multiply
+    random_mocking_bit_gen
+    random_random
+    random_seed_gen_exception
+    random_seed_sequences
+    raw_hash_map
+    raw_hash_set
+    raw_logging_internal
+    scoped_set_env
+    span
+    spinlock_test_common
+    spinlock_wait
+    spy_hash_state
+    stack_consumption
+    stacktrace
+    status
+    str_format
+    str_format_internal
+    strerror
+    strings
+    strings_internal
+    symbolize
+    synchronization
+    test_instance_tracker
+    thread_pool
+    throw_delegate
+    time
+    time_internal_test_util
+    time_zone
+    tracked
+    type_traits
+    unordered_map_constructor_test
+    unordered_map_lookup_test
+    unordered_map_members_test
+    unordered_map_modifiers_test
+    unordered_set_constructor_test
+    unordered_set_lookup_test
+    unordered_set_members_test
+    unordered_set_modifiers_test
+    utility
+    variant
+  )
+  set(_ABSL_LIBRARIES ${_ABSL_LIBRARY_NAMES})
+  foreach(_LIBRARY ${_ABSL_LIBRARY_NAMES})
+    list(APPEND _ABSL_LIBRARIES "absl::${_LIBRARY}")
+  endforeach()
+  set(ABSL_LIBRARIES ${_ABSL_LIBRARIES} CACHE STRING "absl libs")
+endif()
diff --git a/tensorflow/lite/tools/cmake/modules/eigen.cmake b/tensorflow/lite/tools/cmake/modules/eigen.cmake
new file mode 100644
index 0000000..6ad7949
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/eigen.cmake
@@ -0,0 +1,95 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET eigen OR eigen_POPULATED)
+  return()
+endif()
+
+include(OverridableFetchContent)
+
+OverridableFetchContent_Declare(
+  eigen
+  GIT_REPOSITORY https://gitlab.com/libeigen/eigen
+  # TODO: Verify this is the version required by TFLite
+  GIT_TAG b9362fb8f76fbba805b56afbc0f5de0a279631b5
+  # It's not currently (cmake 3.17) possible to shallow clone with a GIT TAG
+  # as cmake attempts to git checkout the commit hash after the clone
+  # which doesn't work as it's a shallow clone hence a different commit hash.
+  # https://gitlab.kitware.com/cmake/cmake/-/issues/17770
+  # GIT_SHALLOW TRUE
+  GIT_PROGRESS TRUE
+  PREFIX "${CMAKE_BINARY_DIR}"
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/eigen"
+  LICENSE_FILE "COPYING.MPL2"
+)
+OverridableFetchContent_GetProperties(eigen)
+if(NOT eigen_POPULATED)
+  OverridableFetchContent_Populate(eigen)
+endif()
+
+# Patch Eigen to disable Fortran compiler check for BLAS and LAPACK tests.
+if(NOT EIGEN_DISABLED_FORTRAN_COMPILER_CHECK)
+  file(WRITE "${eigen_SOURCE_DIR}/cmake/language_support.cmake" "
+      function(workaround_9220 language language_works)
+        set(\${language_works} OFF PARENT_SCOPE)
+      endfunction()"
+  )
+endif()
+# Patch Eigen to disable benchmark suite.
+if(NOT EIGEN_BUILD_BTL)
+  file(WRITE "${eigen_SOURCE_DIR}/bench/spbench/CMakeLists.txt" "")
+endif()
+
+set(EIGEN_DISABLED_FORTRAN_COMPILER_CHECK ON CACHE BOOL "Disabled Fortran")
+
+set(EIGEN_LEAVE_TEST_IN_ALL_TARGET OFF CACHE BOOL
+  "Remove tests from all target."
+)
+set(BUILD_TESTING OFF CACHE BOOL "Disable tests.")
+set(EIGEN_TEST_CXX11 OFF CACHE BOOL "Disable tests of C++11 features.")
+set(EIGEN_BUILD_BTL OFF CACHE BOOL "Disable benchmark suite.")
+set(EIGEN_BUILD_PKGCONFIG OFF CACHE BOOL "Disable pkg-config.")
+set(EIGEN_SPLIT_LARGE_TESTS OFF CACHE BOOL "Disable test splitting.")
+set(EIGEN_DEFAULT_TO_ROW_MAJOR OFF CACHE BOOL
+  "Disable row-major matrix storage"
+)
+set(EIGEN_TEST_NOQT ON CACHE BOOL "Disable Qt support in tests.")
+set(EIGEN_TEST_SSE2 OFF CACHE BOOL "Disable SSE2 test.")
+set(EIGEN_TEST_SSE3 OFF CACHE BOOL "Disable SSE3 test.")
+set(EIGEN_TEST_SSSE3 OFF CACHE BOOL "Disable SSSE3 test.")
+set(EIGEN_TEST_SSE4_1 OFF CACHE BOOL "Disable SSE4.1 test.")
+set(EIGEN_TEST_SSE4_2 OFF CACHE BOOL "Disable SSE4.2 test.")
+set(EIGEN_TEST_AVX OFF CACHE BOOL "Disable AVX test.")
+set(EIGEN_TEST_FMA OFF CACHE BOOL "Disable FMA test.")
+set(EIGEN_TEST_AVX512 OFF CACHE BOOL "Disable AVX512 test.")
+set(EIGEN_TEST_F16C OFF CACHE BOOL "Disable F16C test.")
+set(EIGEN_TEST_ALTIVEC OFF CACHE BOOL "Disable AltiVec test.")
+set(EIGEN_TEST_VSX OFF CACHE BOOL "Disable VSX test.")
+set(EIGEN_TEST_MSA OFF CACHE BOOL "Disable MSA test.")
+set(EIGEN_TEST_NEON OFF CACHE BOOL "Disable NEON test.")
+set(EIGEN_TEST_NEON64 OFF CACHE BOOL "Disable NEON64 test.")
+set(EIGEN_TEST_Z13 OFF CACHE BOOL "Disable Z13 test.")
+set(EIGEN_TEST_Z14 OFF CACHE BOOL "Disable Z14 test.")
+set(EIGEN_TEST_OPENMP OFF CACHE BOOL "Disable OpenMP test.")
+set(EIGEN_TEST_NO_EXPLICIT_VECTORIZATION OFF CACHE BOOL "Disable vectorization")
+set(EIGEN_TEST_X87 OFF CACHE BOOL "Disable X87 instructions test")
+set(EIGEN_TEST_32BIT OFF CACHE BOOL "Disable 32-bit instructions test")
+set(EIGEN_TEST_NO_EXPLICIT_ALIGNMENT OFF CACHE BOOL "Disable alignment test")
+set(EIGEN_TEST_NO_EXCEPTIONS OFF CACHE BOOL "Disable exceptions test")
+set(EIGEN_TEST_SYCL OFF CACHE BOOL "Disable Sycl test")
+set(EIGEN_SYCL_TRISYCL OFF CACHE BOOL "Disable triSYCL test")
+# Make sure only MPL2.0 or more permissively licensed code is included.
+add_compile_definitions(EIGEN_MPL2_ONLY)
+add_subdirectory("${eigen_SOURCE_DIR}" "${eigen_BINARY_DIR}" EXCLUDE_FROM_ALL)
diff --git a/tensorflow/lite/tools/cmake/modules/farmhash.cmake b/tensorflow/lite/tools/cmake/modules/farmhash.cmake
new file mode 100644
index 0000000..09ec7bd
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/farmhash.cmake
@@ -0,0 +1,48 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET farmhash OR farmhash_POPULATED)
+  return()
+endif()
+
+include(OverridableFetchContent)
+
+OverridableFetchContent_Declare(
+  farmhash
+  GIT_REPOSITORY https://github.com/google/farmhash
+  # TODO: Reference the source of this.
+  GIT_TAG 816a4ae622e964763ca0862d9dbd19324a1eaf45
+  # It's not currently possible to shallow clone with a GIT TAG
+  # as cmake attempts to git checkout the commit hash after the clone
+  # which doesn't work as it's a shallow clone hence a different commit hash.
+  # https://gitlab.kitware.com/cmake/cmake/-/issues/17770
+  # GIT_SHALLOW TRUE
+  GIT_PROGRESS TRUE
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/farmhash"
+)
+OverridableFetchContent_GetProperties(farmhash)
+if(NOT farmhash_POPULATED)
+  OverridableFetchContent_Populate(farmhash)
+endif()
+
+set(FARMHASH_SOURCE_DIR "${farmhash_SOURCE_DIR}" CACHE PATH
+  "Source directory for the CMake project."
+)
+
+add_subdirectory(
+  "${CMAKE_CURRENT_LIST_DIR}/farmhash"
+  "${farmhash_BINARY_DIR}"
+  EXCLUDE_FROM_ALL
+)
diff --git a/tensorflow/lite/tools/cmake/modules/farmhash/CMakeLists.txt b/tensorflow/lite/tools/cmake/modules/farmhash/CMakeLists.txt
new file mode 100644
index 0000000..7029926
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/farmhash/CMakeLists.txt
@@ -0,0 +1,39 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+project(farmhash CXX)
+
+set(FARMHASH_SOURCE_DIR "" CACHE PATH
+  "Directory that contains the farmhash project"
+)
+if(NOT FARMHASH_SOURCE_DIR)
+  message(FATAL_ERROR "Must specify source directory")
+endif()
+
+# Transcribed from farmhash/src/Makefile.am
+include(CheckCXXSourceCompiles)
+check_cxx_source_compiles(
+  "int main(int argc, char* argv[]) { return (int)__builtin_expect(0, 0); }"
+  FARMHASH_HAS_BUILTIN_EXPECT
+)
+
+add_library(farmhash
+  "${FARMHASH_SOURCE_DIR}/src/farmhash.cc"
+  "${FARMHASH_SOURCE_DIR}/src/farmhash.h"
+)
+target_include_directories(farmhash PUBLIC "${FARMHASH_SOURCE_DIR}/src")
+if(NOT FARMHASH_HAS_BUILTIN_EXPECT)
+  target_compile_definitions(farmhash PUBLIC FARMHASH_NO_BUILTIN_EXPECT)
+endif()
diff --git a/tensorflow/lite/tools/cmake/modules/fft2d.cmake b/tensorflow/lite/tools/cmake/modules/fft2d.cmake
new file mode 100644
index 0000000..93ac8c1
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/fft2d.cmake
@@ -0,0 +1,41 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET fft2d OR fft2d_POPULATED)
+  return()
+endif()
+
+include(OverridableFetchContent)
+
+OverridableFetchContent_Declare(
+  fft2d
+  URL https://storage.googleapis.com/mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz
+  # TODO: Reference where this comes from.
+  URL_HASH SHA256=ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/fft2d"
+  LICENSE_FILE "readme2d.txt"
+  LICENSE_URL "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.html"
+)
+OverridableFetchContent_GetProperties(fft2d)
+if(NOT fft2d_POPULATED)
+  OverridableFetchContent_Populate(fft2d)
+endif()
+
+set(FFT2D_SOURCE_DIR "${fft2d_SOURCE_DIR}" CACHE PATH "fft2d source")
+add_subdirectory(
+  "${CMAKE_CURRENT_LIST_DIR}/fft2d"
+  "${fft2d_BINARY_DIR}"
+  EXCLUDE_FROM_ALL
+)
diff --git a/tensorflow/lite/tools/cmake/modules/fft2d/CMakeLists.txt b/tensorflow/lite/tools/cmake/modules/fft2d/CMakeLists.txt
new file mode 100644
index 0000000..e7a5ed9
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/fft2d/CMakeLists.txt
@@ -0,0 +1,54 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+project(fft2d C)
+
+set(FFT2D_SOURCE_DIR "" CACHE PATH
+  "Directory that contains the fft2d project"
+)
+if(NOT FFT2D_SOURCE_DIR)
+  message(FATAL_ERROR "Must specify source directory")
+endif()
+
+# fft2d doesn't have a CMake project, so one is defined here, transcribed from
+# sample2d/Makefile.
+
+# A developer should link this library if they haven't provided their own
+# implementation of these allocation methods.
+add_library(fft2d_alloc
+  "${FFT2D_SOURCE_DIR}/alloc.c"
+  "${FFT2D_SOURCE_DIR}/alloc.h"
+)
+target_include_directories(fft2d_alloc PUBLIC "${FFT2D_SOURCE_DIR}")
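+
+# For example, a consumer of the 2-D routines with no custom allocator links
+# both libraries (sketch; `my_dsp` is a hypothetical target):
+#   target_link_libraries(my_dsp PRIVATE fft2d_fftsg2d fft2d_alloc)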
+
+# Requires implementation of fft2d_alloc.
+add_library(fft2d_fft4f2d "${FFT2D_SOURCE_DIR}/fft4f2d.c")
+target_include_directories(fft2d_fft4f2d PRIVATE "${FFT2D_SOURCE_DIR}")
+
+add_library(fft2d_fftsg "${FFT2D_SOURCE_DIR}/fftsg.c")
+
+# Requires implementation of fft2d_alloc.
+add_library(fft2d_fftsg2d "${FFT2D_SOURCE_DIR}/fftsg2d.c")
+target_link_libraries(fft2d_fftsg2d fft2d_fftsg)
+target_include_directories(fft2d_fftsg2d PRIVATE "${FFT2D_SOURCE_DIR}")
+
+# Requires implementation of fft2d_alloc.
+add_library(fft2d_fftsg3d "${FFT2D_SOURCE_DIR}/fftsg3d.c")
+target_link_libraries(fft2d_fftsg3d fft2d_fftsg)
+target_include_directories(fft2d_fftsg3d PRIVATE "${FFT2D_SOURCE_DIR}")
+
+add_library(fft2d_shrtdct "${FFT2D_SOURCE_DIR}/shrtdct.c")
+
+add_library(fft2d ALIAS fft2d_fftsg2d)
diff --git a/tensorflow/lite/tools/cmake/modules/flatbuffers.cmake b/tensorflow/lite/tools/cmake/modules/flatbuffers.cmake
new file mode 100644
index 0000000..38380ca
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/flatbuffers.cmake
@@ -0,0 +1,43 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET flatbuffers OR flatbuffers_POPULATED)
+  return()
+endif()
+
+include(OverridableFetchContent)
+
+OverridableFetchContent_Declare(
+  flatbuffers
+  GIT_REPOSITORY https://github.com/google/flatbuffers
+  GIT_TAG v1.12.0 # TODO: What version does TFLite need?
+  GIT_SHALLOW TRUE
+  GIT_PROGRESS TRUE
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/flatbuffers"
+)
+OverridableFetchContent_GetProperties(flatbuffers)
+if(NOT flatbuffers_POPULATED)
+  OverridableFetchContent_Populate(flatbuffers)
+endif()
+
+# Required for Windows, since windows.h defines macros named min and max,
+# which clash with std::min and std::max.
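+# (e.g. `std::min(a, b)` would otherwise fail to compile once windows.h is
+# included, because the min macro gets expanded by the preprocessor.)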
+add_definitions(-DNOMINMAX=1)
+add_subdirectory(
+  "${flatbuffers_SOURCE_DIR}"
+  "${flatbuffers_BINARY_DIR}"
+  EXCLUDE_FROM_ALL
+)
+remove_definitions(-DNOMINMAX=1)
diff --git a/tensorflow/lite/tools/cmake/modules/gemmlowp.cmake b/tensorflow/lite/tools/cmake/modules/gemmlowp.cmake
new file mode 100644
index 0000000..a0483ab
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/gemmlowp.cmake
@@ -0,0 +1,45 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET gemmlowp OR gemmlowp_POPULATED)
+  return()
+endif()
+
+include(OverridableFetchContent)
+
+OverridableFetchContent_Declare(
+  gemmlowp
+  GIT_REPOSITORY https://github.com/google/gemmlowp
+  GIT_TAG fda83bdc38b118cc6b56753bd540caa49e570745
+  # It's not currently (cmake 3.17) possible to shallow clone with a GIT_TAG,
+  # as cmake attempts to git checkout the commit hash after the clone, which
+  # fails because the shallow clone only contains the branch tip, not the
+  # pinned commit.
+  # https://gitlab.kitware.com/cmake/cmake/-/issues/17770
+  # GIT_SHALLOW TRUE
+  GIT_PROGRESS TRUE
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/gemmlowp"
+)
+
+OverridableFetchContent_GetProperties(gemmlowp)
+if(NOT gemmlowp_POPULATED)
+  OverridableFetchContent_Populate(gemmlowp)
+endif()
+
+set(GEMMLOWP_SOURCE_DIR "${gemmlowp_SOURCE_DIR}" CACHE PATH "Source directory")
+add_subdirectory(
+  "${CMAKE_CURRENT_LIST_DIR}/gemmlowp"
+  "${gemmlowp_BINARY_DIR}"
+  EXCLUDE_FROM_ALL
+)
diff --git a/tensorflow/lite/tools/cmake/modules/gemmlowp/CMakeLists.txt b/tensorflow/lite/tools/cmake/modules/gemmlowp/CMakeLists.txt
new file mode 100644
index 0000000..0aa5ae1
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/gemmlowp/CMakeLists.txt
@@ -0,0 +1,87 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+project(gemmlowp CXX)
+
+option(GEMMLOWP_ADD_HEADERS_TO_TARGETS
+  "Whether to add headers to gemmlowp's interface library targets.
+   This will cause all users of these libraries to also include these headers."
+  OFF
+)
+
+set(GEMMLOWP_SOURCE_DIR "" CACHE PATH
+  "Directory that contains the gemmlowp project"
+)
+if(NOT GEMMLOWP_SOURCE_DIR)
+  message(FATAL_ERROR "Must specify source directory")
+endif()
+
+# gemmlowp doesn't have a CMake project so this is transcribed from
+# gemmlowp/BUILD.
+
+file(GLOB GEMMLOWP_EIGHTBITINT_HEADERS
+  "${GEMMLOWP_SOURCE_DIR}/eight_bit_int_gemm/*.h"
+)
+file(GLOB GEMMLOWP_EIGHTBITINT_SOURCES
+  "${GEMMLOWP_SOURCE_DIR}/eight_bit_int_gemm/*.cc"
+)
+file(GLOB GEMMLOWP_FIXEDPOINT_HEADERS "${GEMMLOWP_SOURCE_DIR}/fixedpoint/*.h")
+file(GLOB GEMMLOWP_INTERNAL_HEADERS "${GEMMLOWP_SOURCE_DIR}/internal/*.h")
+file(GLOB GEMMLOWP_META_HEADERS "${GEMMLOWP_SOURCE_DIR}/meta/*.h")
+file(GLOB GEMMLOWP_PROFILING_HEADERS "${GEMMLOWP_SOURCE_DIR}/profiling/*.h")
+file(GLOB GEMMLOWP_PUBLIC_HEADERS "${GEMMLOWP_SOURCE_DIR}/public/*.h")
+
+set(GEMMLOWP_PRIVATE_HEADERS "")
+list(APPEND GEMMLOWP_PRIVATE_HEADERS ${GEMMLOWP_FIXEDPOINT_HEADERS})
+list(APPEND GEMMLOWP_PRIVATE_HEADERS ${GEMMLOWP_INTERNAL_HEADERS})
+
+add_library(gemmlowp_private INTERFACE)
+if(GEMMLOWP_ADD_HEADERS_TO_TARGETS)
+  target_sources(gemmlowp_private INTERFACE ${GEMMLOWP_PRIVATE_HEADERS})
+endif()
+target_include_directories(gemmlowp_private INTERFACE "${GEMMLOWP_SOURCE_DIR}")
+
+add_library(gemmlowp INTERFACE)
+if(GEMMLOWP_ADD_HEADERS_TO_TARGETS)
+  target_sources(gemmlowp INTERFACE ${GEMMLOWP_PUBLIC_HEADERS})
+endif()
+target_include_directories(gemmlowp INTERFACE "${GEMMLOWP_SOURCE_DIR}/public")
+target_link_libraries(gemmlowp INTERFACE gemmlowp_private)
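+
+# Consumers normally link only the public interface target, which pulls in
+# the private headers transitively (sketch; `my_kernel` is a hypothetical
+# target):
+#   target_link_libraries(my_kernel PRIVATE gemmlowp)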
+
+add_library(gemmlowp_eight_bit_int_gemm
+  ${GEMMLOWP_EIGHTBITINT_SOURCES}
+  ${GEMMLOWP_EIGHTBITINT_HEADERS}
+)
+target_include_directories(gemmlowp_eight_bit_int_gemm
+  PUBLIC "${GEMMLOWP_SOURCE_DIR}/eight_bit_int_gemm"
+)
+
+add_library(gemmlowp_fixedpoint INTERFACE)
+if(GEMMLOWP_ADD_HEADERS_TO_TARGETS)
+  target_sources(gemmlowp_fixedpoint INTERFACE ${GEMMLOWP_FIXEDPOINT_HEADERS})
+endif()
+target_include_directories(gemmlowp_fixedpoint
+  INTERFACE "${GEMMLOWP_SOURCE_DIR}/fixedpoint"
+)
+target_link_libraries(gemmlowp_fixedpoint INTERFACE gemmlowp_private)
+
+add_library(gemmlowp_profiler INTERFACE)
+if(GEMMLOWP_ADD_HEADERS_TO_TARGETS)
+  target_sources(gemmlowp_profiler INTERFACE ${GEMMLOWP_PROFILING_HEADERS})
+endif()
+target_include_directories(gemmlowp_profiler
+  INTERFACE "${GEMMLOWP_SOURCE_DIR}/profiling"
+)
diff --git a/tensorflow/lite/tools/cmake/modules/neon2sse.cmake b/tensorflow/lite/tools/cmake/modules/neon2sse.cmake
new file mode 100644
index 0000000..505835b
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/neon2sse.cmake
@@ -0,0 +1,40 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET neon2sse OR neon2sse_POPULATED)
+  return()
+endif()
+
+include(OverridableFetchContent)
+
+OverridableFetchContent_Declare(
+  neon2sse
+  GIT_REPOSITORY https://github.com/intel/ARM_NEON_2_x86_SSE
+  GIT_TAG master
+  GIT_SHALLOW TRUE
+  GIT_PROGRESS TRUE
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/neon2sse"
+)
+
+OverridableFetchContent_GetProperties(neon2sse)
+if(NOT neon2sse_POPULATED)
+  OverridableFetchContent_Populate(neon2sse)
+endif()
+
+add_subdirectory(
+  "${neon2sse_SOURCE_DIR}"
+  "${neon2sse_BINARY_DIR}"
+  EXCLUDE_FROM_ALL
+)
diff --git a/tensorflow/lite/tools/cmake/modules/ruy.cmake b/tensorflow/lite/tools/cmake/modules/ruy.cmake
new file mode 100644
index 0000000..02a99cd
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/ruy.cmake
@@ -0,0 +1,41 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(TARGET ruy OR ruy_POPULATED)
+  return()
+endif()
+
+include(OverridableFetchContent)
+
+OverridableFetchContent_Declare(
+  ruy
+  GIT_REPOSITORY https://github.com/google/ruy
+  GIT_TAG master # TODO
+  GIT_SHALLOW TRUE
+  GIT_PROGRESS TRUE
+  SOURCE_DIR "${CMAKE_BINARY_DIR}/ruy"
+)
+OverridableFetchContent_GetProperties(ruy)
+if(NOT ruy_POPULATED)
+  OverridableFetchContent_Populate(ruy)
+endif()
+
+set(RUY_SOURCE_DIR "${ruy_SOURCE_DIR}" CACHE PATH "RUY source directory")
+
+add_subdirectory(
+  "${CMAKE_CURRENT_LIST_DIR}/ruy"
+  "${ruy_BINARY_DIR}"
+  EXCLUDE_FROM_ALL
+)
diff --git a/tensorflow/lite/tools/cmake/modules/ruy/CMakeLists.txt b/tensorflow/lite/tools/cmake/modules/ruy/CMakeLists.txt
new file mode 100644
index 0000000..d88d047
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/ruy/CMakeLists.txt
@@ -0,0 +1,38 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+cmake_minimum_required(VERSION 3.16)
+
+project(ruy CXX)
+
+set(CMAKE_CXX_STANDARD 14)  # Some components require C++14.
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+set(RUY_SOURCE_DIR "" CACHE PATH
+  "Directory that contains the RUY project"
+)
+if(NOT RUY_SOURCE_DIR)
+  message(FATAL_ERROR "Must specify source directory")
+endif()
+
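+# Collect the ruy sources, keeping only C/C++ sources and headers and
+# excluding tests, benchmarks, examples, and the gtest wrapper.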
+file(GLOB RUY_SOURCES "${RUY_SOURCE_DIR}/ruy/*.*")
+list(FILTER RUY_SOURCES INCLUDE REGEX ".*\\.(c|cc|h)$")
+list(FILTER RUY_SOURCES EXCLUDE REGEX ".*(_test)\\.(c|cc|h)$")
+list(FILTER RUY_SOURCES EXCLUDE REGEX ".*/(benchmark|example|test_.*)\.cc$")
+list(FILTER RUY_SOURCES EXCLUDE REGEX ".*/gtest_wrapper\\.h$")
+
+add_library(ruy ${RUY_SOURCES})
+target_include_directories(ruy PUBLIC "${RUY_SOURCE_DIR}")
+
diff --git a/tensorflow/lite/tools/list_flex_ops.cc b/tensorflow/lite/tools/list_flex_ops.cc
index 6fdfd06..b4db4f7 100644
--- a/tensorflow/lite/tools/list_flex_ops.cc
+++ b/tensorflow/lite/tools/list_flex_ops.cc
@@ -20,7 +20,7 @@
 #include <vector>
 
 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
-#include "include/json/json.h"
+#include "json/json.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
diff --git a/tensorflow/lite/tools/list_flex_ops_no_kernel.cc b/tensorflow/lite/tools/list_flex_ops_no_kernel.cc
index ea4d41c..68d40be 100644
--- a/tensorflow/lite/tools/list_flex_ops_no_kernel.cc
+++ b/tensorflow/lite/tools/list_flex_ops_no_kernel.cc
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "include/json/json.h"
+#include "json/json.h"
 #include "tensorflow/lite/tools/list_flex_ops.h"
 
 namespace tflite {
diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile
index c7ddff5..eb07503 100644
--- a/tensorflow/lite/tools/make/Makefile
+++ b/tensorflow/lite/tools/make/Makefile
@@ -315,9 +315,6 @@
 $(OBJDIR)%.o: %.c
 	@mkdir -p $(dir $@)
 	$(CC) $(CFLAGS) $(INCLUDES) -c $< -o $@
-$(OBJDIR)%.o: %.cpp
-	@mkdir -p $(dir $@)
-	$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
 
 # The target that's compiled if there's no command-line arguments.
 all: $(LIB_PATH)  $(MINIMAL_BINARY) $(BENCHMARK_BINARY) $(BENCHMARK_PERF_OPTIONS_BINARY)
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index b61cebd..f517805 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -265,6 +265,7 @@
 tensorflow/third_party/toolchains/remote_config/configs.bzl
 tensorflow/third_party/toolchains/remote_config/containers.bzl
 tensorflow/third_party/toolchains/remote_config/rbe_config.bzl
+tensorflow/third_party/typing_extensions.BUILD
 tensorflow/third_party/wrapt.BUILD
 tensorflow/third_party/zlib.BUILD
 tensorflow/tools/build_info/BUILD
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ac58ae5..f39797f 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -678,20 +678,21 @@
         "client/tf_session_helper.h",
         "lib/core/numpy.h",
         "lib/core/safe_ptr.h",
+        "lib/core/safe_pyobject_ptr.h",
         "//tensorflow/c:headers",
-        "//tensorflow/c:pywrap_required_hdrs",
         "//tensorflow/c/eager:headers",
         "//tensorflow/c/eager:pywrap_required_hdrs",
         "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
-        "//tensorflow/core/framework:pywrap_required_hdrs",
     ],
     module_name = "_pywrap_tf_session",
     deps = [
         ":pybind11_lib",
         ":pybind11_status",
+        "//tensorflow/core/framework:pywrap_required_hdrs",
         "//third_party/py/numpy:headers",
+        "//tensorflow/c:pywrap_required_hdrs",
         "@pybind11",
         "//third_party/python_runtime:headers",
         "//tensorflow/core:protos_all_cc",
@@ -880,6 +881,7 @@
     hdrs = [
         "lib/core/ndarray_tensor.h",
         "lib/core/safe_ptr.h",
+        "lib/core/safe_pyobject_ptr.h",
         ":py_exception_registry_hdr",
         "//tensorflow/c:checkpoint_reader_hdrs",
         "//tensorflow/c:headers",
@@ -940,12 +942,17 @@
     ],
 )
 
+# TODO(edloper): Remove unused dependency on safe_ptr.  (Blocker: there are
+# targets that rely on cpp_python_util to pull in safe_ptr's
+# third_party/tensorflow/c:c_api_no_xla dependency, which registers
+# ops/gradients, rather than depending on it themselves.)
 cc_library(
     name = "cpp_python_util",
     srcs = ["util/util.cc"],
     hdrs = ["util/util.h"],
     deps = [
         ":safe_ptr",
+        ":safe_pyobject_ptr",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//third_party/python_runtime:headers",
@@ -1001,6 +1008,15 @@
 )
 
 cc_library(
+    name = "safe_pyobject_ptr",
+    srcs = ["lib/core/safe_pyobject_ptr.cc"],
+    hdrs = ["lib/core/safe_pyobject_ptr.h"],
+    deps = [
+        "//third_party/python_runtime:headers",
+    ],
+)
+
+cc_library(
     name = "safe_ptr",
     srcs = [
         "lib/core/safe_ptr.cc",
@@ -1008,6 +1024,7 @@
     ],
     hdrs = ["lib/core/safe_ptr.h"],
     deps = [
+        ":safe_pyobject_ptr",
         "//tensorflow/c:c_api_no_xla",
         "//third_party/python_runtime:headers",
     ],
@@ -1021,8 +1038,8 @@
         "lib/core/ndarray_tensor_bridge.h",
         "lib/core/numpy.h",
         "lib/core/safe_ptr.h",
+        "lib/core/safe_pyobject_ptr.h",
         "//tensorflow/c:headers",
-        "//tensorflow/c:pywrap_required_hdrs",
         "//tensorflow/c/eager:headers",
     ],
     features = [
@@ -1033,6 +1050,7 @@
     ]),
     deps = [
         ":numpy_lib",
+        "//tensorflow/c:pywrap_required_hdrs",
         "//tensorflow/c:tf_status_headers",
         "//tensorflow/core:framework_internal_headers_lib",
         "//tensorflow/core/common_runtime:core_cpu_headers_lib",
@@ -1628,12 +1646,46 @@
 )
 
 cc_library(
+    name = "py_context_manager",
+    srcs = ["framework/py_context_manager.cc"],
+    hdrs = ["framework/py_context_manager.h"],
+    deps = [
+        ":safe_pyobject_ptr",
+        "//tensorflow/core:lib",  # for core/platform/logging.h
+        "//third_party/python_runtime:headers",
+    ],
+)
+
+# Pybind extension used by py_context_manager_test.
+tf_python_pybind_extension(
+    name = "_py_context_manager",
+    srcs = ["framework/py_context_manager_pybind.cc"],
+    module_name = "_py_context_manager",
+    deps = [
+        ":py_context_manager",
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
+tf_py_test(
+    name = "py_context_manager_test",
+    srcs = ["framework/py_context_manager_test.py"],
+    python_version = "PY3",
+    tags = ["no_pip"],
+    tfrt_enabled = True,
+    deps = [
+        ":_py_context_manager",
+    ],
+)
+
+cc_library(
     name = "op_def_util_cc",
     srcs = ["framework/op_def_util.cc"],
     hdrs = ["framework/op_def_util.h"],
     deps = [
         ":cpp_python_util",
-        ":safe_ptr",
+        ":safe_pyobject_ptr",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/strings",
     ],
@@ -1644,6 +1696,8 @@
 # depending on that target adds dependencies that register objects; and since the
 # extension is built as a shared object in some kokoro tests, this causes those objects
 # to get registered multiple times (which fails).
+# TODO(edloper): Simplify this, once cpp_python_util is changed to not depend on
+# safe_ptr (which transitively depends on third_party/tensorflow/c:c_api_no_xla).
 tf_python_pybind_extension(
     name = "_op_def_util",
     srcs = [
@@ -1653,6 +1707,7 @@
     hdrs = [
         "framework/op_def_util.h",
         "lib/core/safe_ptr.h",
+        "lib/core/safe_pyobject_ptr.h",
         "util/util.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
@@ -2994,6 +3049,7 @@
     ],
     deps = [
         "//tensorflow/c/kernels:histogram_summary_op_lib",
+        "//tensorflow/c/kernels:merge_summary_op_lib",
         "//tensorflow/c/kernels:summary_op_lib",
         "//tensorflow/core:logging_ops_op_lib",
     ],
@@ -3421,6 +3477,25 @@
 )
 
 tf_py_test(
+    name = "collective_ops_multi_worker_test",
+    size = "medium",
+    srcs = ["ops/collective_ops_multi_worker_test.py"],
+    python_version = "PY3",
+    tags = ["no_rocm"],
+    tfrt_enabled = False,
+    deps = [
+        ":collective_ops",
+        ":constant_op",
+        ":errors",
+        "//tensorflow/python/distribute:multi_process_runner",
+        "//tensorflow/python/distribute:multi_worker_test_base",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:test",
+    ],
+)
+
+tf_py_test(
     name = "collective_ops_xla_test",
     size = "small",
     srcs = ["ops/collective_ops_xla_test.py"],
@@ -8310,6 +8385,7 @@
     srcs = ["mlir_wrapper.cc"],
     hdrs = [
         "lib/core/safe_ptr.h",
+        "lib/core/safe_pyobject_ptr.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
         "//tensorflow/compiler/mlir/python:pywrap_mlir_hdrs",
@@ -8341,27 +8417,28 @@
     srcs = ["tfe_wrapper.cc"],
     hdrs = [
         "lib/core/safe_ptr.h",
+        "lib/core/safe_pyobject_ptr.h",
         "util/util.h",
         ":py_exception_registry_hdr",
         "//tensorflow/c:headers",
-        "//tensorflow/c:pywrap_required_hdrs",
         "//tensorflow/c/eager:headers",
         "//tensorflow/c/eager:pywrap_required_hdrs",
         "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
         "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
-        "//tensorflow/core/framework:pywrap_required_hdrs",
         "//tensorflow/python/eager:pywrap_required_hdrs",
     ],
     module_name = "_pywrap_tfe",
     deps = [
         ":pybind11_lib",
         ":pybind11_status",
+        "//tensorflow/core/framework:pywrap_required_hdrs",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
+        "//tensorflow/c:pywrap_required_hdrs",
         "@pybind11",
         "//third_party/python_runtime:headers",
         "//tensorflow/c/experimental/saved_model/core:pywrap_required_hdrs",
@@ -8412,6 +8489,7 @@
     name = "_pywrap_parallel_device",
     srcs = [
         "lib/core/safe_ptr.h",
+        "lib/core/safe_pyobject_ptr.h",
         "//tensorflow/c:headers",
         "//tensorflow/c/eager:headers",
         "//tensorflow/c/eager/parallel_device:headers",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index b5acf23..ce26271 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -36,6 +36,7 @@
 
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
+from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
 
 from tensorflow.python.eager import context
 
diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index c3fc879..98f7664 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -60,10 +60,10 @@
   def _create_nonlocal_declarations(self, vars_):
     vars_ = set(vars_)
     results = []
-    global_vars = self.state[_Function].scope.globals
+    global_vars = self.state[_Function].scope.globals & vars_
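+    # Only the globals that this block actually modifies need a `global`
+    # declaration; e.g. with vars_ = {'i', 'g'} and scope.globals = {'g', 'h'},
+    # only 'g' is declared global.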
 
     if global_vars:
-      results.append(gast.Global([str(v) for v in vars_]))
+      results.append(gast.Global([str(v) for v in global_vars]))
 
     nonlocal_vars = [
         v for v in vars_ if not v.is_composite() and v not in global_vars]
@@ -180,6 +180,7 @@
     defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
     live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
     live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+    fn_scope = self.state[_Function].scope
 
     basic_scope_vars = self._get_block_basic_vars(
         modified,
@@ -191,8 +192,9 @@
     # Variables that are modified inside the scope, but not defined
     # before entering it. Only simple variables must be defined. The
     # composite ones will be implicitly checked at runtime.
-    # This covers loop variables as well as variables that
-    undefined = tuple(v for v in modified - defined_in if not v.is_composite())
+    possibly_undefined = (
+        modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
+    undefined = tuple(v for v in possibly_undefined if not v.is_composite())
 
     # Variables that are modified inside the scope, and depend on values outside
     # it.
diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
index 87f59be..497b329 100644
--- a/tensorflow/python/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -38,6 +38,8 @@
 
 
 for_unaffected_global = None
+for_mixed_globals_nonglobals = None
+for_test_global_local = None
 
 
 class ControlFlowTestBase(converter_testing.TestCase):
@@ -76,6 +78,25 @@
     self.assertTransformedResult(f, constant_op.constant(5),
                                  (25, 5, 0, 5))
 
+  def test_mixed_globals_nonglobals(self):
+
+    def f(n):
+      global for_mixed_globals_nonglobals
+      i = 0
+      j = 0
+      for_mixed_globals_nonglobals = 0
+      while i < n:
+        while j < i:
+          j += 3
+        u = i + j  # 'u' is not defined within the inner loop
+        for_mixed_globals_nonglobals += u
+        i += 1
+        j = 0
+      return for_mixed_globals_nonglobals, i, j, n
+
+    self.assertTransformedResult(f, constant_op.constant(5),
+                                 (25, 5, 0, 5))
+
   def test_composite_state_complex(self):
 
     class TestClassX(object):
@@ -457,6 +478,23 @@
     self.assertTransformedResult(f, constant_op.constant(1), 5)
     self.assertTransformedResult(f, constant_op.constant(-1), -1)
 
+  def test_global_local(self):
+
+    def f(n):
+      if n > 0:
+        global for_test_global_local
+        if for_test_global_local is None:
+          for_test_global_local = 1
+        else:
+          for_test_global_local += 1
+        n += for_test_global_local
+      return n
+
+    tr = self.transform(f, control_flow)
+    assert for_test_global_local is None
+    self.assertEqual(tr(1), 2)
+    self.assertEqual(for_test_global_local, 1)
+
   def test_no_outputs(self):
 
     def f(n):
diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc
index 6bc8cb2..ac656d3 100644
--- a/tensorflow/python/client/tf_session_wrapper.cc
+++ b/tensorflow/python/client/tf_session_wrapper.cc
@@ -166,7 +166,7 @@
           return out_handle;
         });
   m.def("_TF_SetTarget", TF_SetTarget);
-  m.def("_TF_SetConfig", [](TF_SessionOptions* options, py::str proto) {
+  m.def("_TF_SetConfig", [](TF_SessionOptions* options, py::bytes proto) {
     tensorflow::Safe_TF_StatusPtr status =
         tensorflow::make_safe(TF_NewStatus());
     tensorflow::Safe_TF_BufferPtr buf =
@@ -398,7 +398,7 @@
   });
 
   m.def("SetHandleShapeAndType",
-        [](TF_Graph* graph, TF_Output output, py::str proto) {
+        [](TF_Graph* graph, TF_Output output, py::bytes proto) {
           tensorflow::Safe_TF_StatusPtr status =
               tensorflow::make_safe(TF_NewStatus());
           tensorflow::Safe_TF_BufferPtr buf =
@@ -614,7 +614,7 @@
         });
 
   m.def("TF_SetAttrValueProto", [](TF_OperationDescription* desc,
-                                   const char* attr_name, py::str proto) {
+                                   const char* attr_name, py::bytes proto) {
     tensorflow::Safe_TF_StatusPtr status =
         tensorflow::make_safe(TF_NewStatus());
     tensorflow::Safe_TF_BufferPtr buf =
@@ -673,7 +673,7 @@
   m.def("TF_DeleteBuffer", &TF_DeleteBuffer);
   m.def(
       "TF_NewBufferFromString",
-      [](py::str buffer_as_string) {
+      [](py::bytes buffer_as_string) {
         tensorflow::Safe_TF_BufferPtr buf = tensorflow::make_safe(
             ProtoStringToTFBuffer(buffer_as_string.ptr()));
         return TF_NewBufferFromString(buf.get()->data, buf.get()->length);
@@ -853,7 +853,7 @@
         py::call_guard<py::gil_scoped_release>());
 
   m.def("TF_FunctionSetAttrValueProto",
-        [](TF_Function* func, const char* attr_name, py::str proto) {
+        [](TF_Function* func, const char* attr_name, py::bytes proto) {
           tensorflow::Safe_TF_StatusPtr status =
               tensorflow::make_safe(TF_NewStatus());
           tensorflow::Safe_TF_BufferPtr buf =
@@ -887,7 +887,7 @@
 
   m.def(
       "TF_FunctionImportFunctionDef",
-      [](py::str proto) {
+      [](py::bytes proto) {
         tensorflow::Safe_TF_StatusPtr status =
             tensorflow::make_safe(TF_NewStatus());
         tensorflow::Safe_TF_BufferPtr buf =
@@ -991,7 +991,7 @@
 
   m.def(
       "TF_NewServer",
-      [](py::str proto) {
+      [](py::bytes proto) {
         tensorflow::Safe_TF_StatusPtr status =
             tensorflow::make_safe(TF_NewStatus());
         tensorflow::Safe_TF_BufferPtr buf =
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 7ce55b1..36b1043 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -33,7 +33,7 @@
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 8, 19)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 8, 31)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 
diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD
index 387d379..6570e51 100644
--- a/tensorflow/python/compiler/tensorrt/BUILD
+++ b/tensorflow/python/compiler/tensorrt/BUILD
@@ -145,9 +145,8 @@
     ],
     python_version = "PY3",
     tags = [
-        "no_cuda11",  # TODO(b/165611343): Need to address the failures.
         "no_cuda_on_cpu_tap",
-        "no_oss",
+        "no_oss",  # TODO(b/165611343): Need to address the failures for CUDA 11 in OSS build.
         "no_rocm",
         "no_windows",
         "nomac",
@@ -170,6 +169,7 @@
     ],
     python_version = "PY3",
     tags = [
+        "no_cuda11",  # TODO(b/166308253): enable the test for CUDA 11.
         "no_cuda_on_cpu_tap",
         "no_oss",  # TODO(b/125290478): allow running in at least some OSS configurations.
         "no_pip",
diff --git a/tensorflow/python/compiler/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py
index 195382c..9d2d3ab 100644
--- a/tensorflow/python/compiler/tensorrt/test/base_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/base_test.py
@@ -70,12 +70,6 @@
         ]
     }
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
-    return super().ShouldRunTest(run_params)
-
 
 class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
 
@@ -136,12 +130,6 @@
     return conversion_params._replace(
         rewriter_config_template=rewrite_config_with_trt)
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
-    return super().ShouldRunTest(run_params)
-
 
 class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase):
 
diff --git a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
index 26e911e..3f2a546 100644
--- a/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/combined_nms_test.py
@@ -89,9 +89,6 @@
     }
 
   def ShouldRunTest(self, run_params):
-    # TODO(b/162447069): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, 'Skip test due to b/162447069')
     # There is no CombinedNonMaxSuppression op for GPU at the moment, so
     # calibration will fail.
     # TODO(laigd): fix this.
diff --git a/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py b/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py
index 9e71b9e..ccbaf9e 100644
--- a/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py
@@ -60,12 +60,6 @@
     """The relative tolerance to compare floating point results."""
     return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, 'Skip test due to b/162448349')
-    return super().ShouldRunTest(run_params)
-
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/compiler/tensorrt/test/conv2d_test.py b/tensorflow/python/compiler/tensorrt/test/conv2d_test.py
index 400c17b..df1adce 100644
--- a/tensorflow/python/compiler/tensorrt/test/conv2d_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/conv2d_test.py
@@ -114,12 +114,6 @@
       return 4e-02
     return super(Conv2DNCHWTest, self).ExpectedRelativeTolerance(run_params)
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
-    return super().ShouldRunTest(run_params)
-
 
 class Conv2DNHWCTest(trt_test.TfTrtIntegrationTestBase):
   """Testing conversion of Conv2D (data_format=NCHW) in TF-TRT conversion."""
@@ -143,12 +137,6 @@
     """Return the expected engines to build."""
     return ["TRTEngineOp_0"]
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
-    return super().ShouldRunTest(run_params)
-
 
 class Conv2DStridedNCHWTest(trt_test.TfTrtIntegrationTestBase):
   """Testing conversion of strided Conv2D (data_format=NCHW)."""
@@ -180,12 +168,6 @@
     """Return the expected engines to build."""
     return ["TRTEngineOp_0"]
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
-    return super().ShouldRunTest(run_params)
-
 
 class Conv2DTranposeTest(trt_test.TfTrtIntegrationTestBase):
   """Testing conversion of conv2d_transpose (AKA Conv2DBackpropInput)"""
diff --git a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
index f02ad08..95dbe72 100644
--- a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
@@ -98,9 +98,6 @@
     return ["TRTEngineOp_0"]
 
   def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
     return (run_params.dynamic_engine and not trt_test.IsQuantizationMode(
         run_params.precision_mode)), "test dynamic engine and non-INT8"
 
diff --git a/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py b/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py
index c1f0a00..056edc3 100644
--- a/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py
@@ -67,12 +67,6 @@
     """The relative tolerance to compare floating point results."""
     return 0.1
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
-    return super().ShouldRunTest(run_params)
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py
index 687a124..b57bee6 100644
--- a/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -72,12 +72,6 @@
     """Return the expected engines to build."""
     return ["TRTEngineOp_0", "TRTEngineOp_1"]
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162447069): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162447069")
-    return super().ShouldRunTest(run_params)
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py b/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py
index 39fee5c..f377fe8 100644
--- a/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py
@@ -61,12 +61,6 @@
         "TRTEngineOp_1": ["weights", "conv"]
     }
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162447069): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162447069")
-    return super().ShouldRunTest(run_params)
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
index d859407..000b231 100644
--- a/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
@@ -261,10 +261,6 @@
     if not is_tensorrt_enabled():
       return
 
-    # TODO(b/162447069): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return
-
     model_dir = test.test_src_dir_path(
         'python/compiler/tensorrt/test/testdata/mnist')
 
diff --git a/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py
index 43034e8..8fd9606 100644
--- a/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py
@@ -76,12 +76,6 @@
     super(trt_test.TfTrtIntegrationTestBase, self).setUp()
     os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
-    return super().ShouldRunTest(run_params)
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py
index 7b1f7e0..9d81cd6 100644
--- a/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py
@@ -67,12 +67,6 @@
     super(trt_test.TfTrtIntegrationTestBase, self).setUp()
     os.environ["TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION"] = "True"
 
-  def ShouldRunTest(self, run_params):
-    # TODO(b/162448349): Enable the test for TRT 7.1.3.
-    if trt_test.IsTensorRTVersionGreaterEqual(7, 1, 3):
-      return (False, "Skip test due to b/162448349")
-    return super().ShouldRunTest(run_params)
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py
index a091bdc..e7c84ee 100644
--- a/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+from collections import namedtuple
 from absl.testing import parameterized
 
 from tensorflow.python.data.experimental.ops import compression_ops
@@ -25,14 +26,24 @@
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import combinations
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 
 
 def _test_objects():
+
+  Item = namedtuple("Item", "id name")
+
   return [
       combinations.NamedObject("int", 1),
       combinations.NamedObject("string", "dog"),
       combinations.NamedObject("tuple", (1, 1)),
+      combinations.NamedObject("nested_tuple", ((1, 1), (2, 2))),
+      combinations.NamedObject("named_tuple", Item(id=1, name="item1")),
+      combinations.NamedObject("unicode", "アヒル"),
+      combinations.NamedObject(
+          "nested_named_tuple",
+          (Item(id=1, name="item1"), Item(id=2, name="item2"))),
       combinations.NamedObject("int_string_tuple", (1, "dog")),
       combinations.NamedObject(
           "sparse",
@@ -50,11 +61,32 @@
   ]
 
 
+def _test_v2_eager_only_objects():
+  return [
+      combinations.NamedObject(
+          "ragged",
+          ragged_factory_ops.constant([[0, 1, 2, 3], [4, 5], [6, 7, 8], [9]])),
+      combinations.NamedObject(
+          "sparse_ragged_structured", {
+              "sparse":
+                  sparse_tensor.SparseTensorValue(
+                      indices=[[0, 0], [1, 2]],
+                      values=[1, 2],
+                      dense_shape=[3, 4]),
+              "ragged":
+                  ragged_factory_ops.constant([[0, 1, 2, 3], [9]])
+          })
+  ]
+
+
 class CompressionOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(
       combinations.times(test_base.default_test_combinations(),
-                         combinations.combine(element=_test_objects())))
+                         combinations.combine(element=_test_objects())) +
+      combinations.times(
+          test_base.v2_eager_only_combinations(),
+          combinations.combine(element=_test_v2_eager_only_objects())))
   def testCompression(self, element):
     element = element._obj
 
@@ -65,7 +97,10 @@
 
   @combinations.generate(
       combinations.times(test_base.default_test_combinations(),
-                         combinations.combine(element=_test_objects())))
+                         combinations.combine(element=_test_objects())) +
+      combinations.times(
+          test_base.v2_eager_only_combinations(),
+          combinations.combine(element=_test_v2_eager_only_objects())))
   def testDatasetCompression(self, element):
     element = element._obj
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
index 16bb1ec..cab4126 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py
@@ -104,6 +104,28 @@
   return functools.reduce(reduce_fn, cases, [])
 
 
+def _disable_intra_op_parallelism_test_combinations():
+
+  def make_tensor_dataset():
+    return dataset_ops.Dataset.from_tensors(42)
+
+  def make_map_dataset():
+    return dataset_ops.Dataset.from_tensors(42).map(lambda x: x + 1)
+
+  cases = [
+      ("FromTensors", make_tensor_dataset, [42]),
+      ("Map", make_map_dataset, [43]),
+  ]
+
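+  # Fold each (name, dataset_fn, expected_output) case into a flat list of
+  # test combinations; `x` is the accumulated list and `y` the next case.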
+  def reduce_fn(x, y):
+    name, dataset_fn, expected_output = y
+    return x + combinations.combine(
+        dataset_fn=combinations.NamedObject(name, dataset_fn),
+        expected_output=[expected_output])
+
+  return functools.reduce(reduce_fn, cases, [])
+
+
 class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.default_test_combinations())
@@ -186,17 +208,18 @@
     dataset = dataset.with_options(options)
     self.assertDatasetProduces(dataset, expected_output=[[0]])
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testOptimizationDisableIntraOpParallelism(self):
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         _disable_intra_op_parallelism_test_combinations()))
+  def testOptimizationDisableIntraOpParallelism(self, dataset_fn,
+                                                expected_output):
     os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "disable_intra_op_parallelism"
     os.environ["TF_JOB_NAME"] = "test_job"
 
-    dataset = dataset_ops.Dataset.range(10).map(lambda x: x+1)
+    dataset = dataset_fn()
     dataset = dataset.apply(testing.assert_next(["MaxIntraOpParallelism"]))
 
-    options = dataset_ops.Options()
-    dataset = dataset.with_options(options)
-    self.assertDatasetProduces(dataset, expected_output=list(range(1, 11)))
+    self.assertDatasetProduces(dataset, expected_output=expected_output)
 
     del os.environ["TF_DATA_EXPERIMENT_OPT_IN"]
     del os.environ["TF_JOB_NAME"]
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py
index cbff39b..e9a4d52 100644
--- a/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_with_slack_test.py
@@ -86,18 +86,16 @@
     self.assertDatasetProduces(dataset, range(1, 11))
 
   @combinations.generate(test_base.default_test_combinations())
-  def testErrorWithoutPrefetch(self):
-    """The rewrite fails if there is no prefetch() in the pipeline."""
+  def testNoErrorWithoutPrefetch(self):
+    """The rewrite should not fail if there is no prefetch() in the pipeline."""
     dataset = dataset_ops.Dataset.range(10)
     options = dataset_ops.Options()
     options.experimental_slack = True
     dataset = dataset.with_options(options)
-    with self.assertRaises(errors.InvalidArgumentError):
-      get_next = self.getNext(dataset)
-      self.evaluate(get_next())
+    self.assertDatasetProduces(dataset, range(10))
 
   @combinations.generate(test_base.default_test_combinations())
-  def testErrorWithInvalidDataset(self):
+  def testNoErrorWithInvalidDataset(self):
     """With a nested dataset op after prefetch, the rewrite should fail."""
     dataset = dataset_ops.Dataset.range(10)
     dataset = dataset.prefetch(1)
@@ -105,9 +103,7 @@
     options = dataset_ops.Options()
     options.experimental_slack = True
     dataset = dataset.with_options(options)
-    with self.assertRaises(errors.InvalidArgumentError):
-      get_next = self.getNext(dataset)
-      self.evaluate(get_next())
+    self.assertDatasetProduces(dataset, range(10))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
index b268ba2..8ce904e 100644
--- a/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
+++ b/tensorflow/python/data/experimental/service/server_lib_wrapper.cc
@@ -63,7 +63,7 @@
         }
         std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
         tensorflow::Status status =
-            tensorflow::data::NewDispatchServer(config, &server);
+            tensorflow::data::NewDispatchServer(config, server);
         tensorflow::MaybeRaiseFromStatus(status);
         return server;
       },
@@ -80,7 +80,7 @@
         }
         std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
         tensorflow::Status status =
-            tensorflow::data::NewWorkerServer(config, &server);
+            tensorflow::data::NewWorkerServer(config, server);
         tensorflow::MaybeRaiseFromStatus(status);
         return server;
       },
diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
index 310a60b..6bc64dd 100644
--- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
@@ -600,10 +600,8 @@
       _make_distributed_dataset(dataset, dispatcher)
       return dataset
 
-    with self.assertRaisesRegex(
-        errors.InvalidArgumentError, r"The `.distribute\(...\)` dataset "
-        "transformation is not supported within tf.data functions"):
-      ds = ds.interleave(interleave_fn, cycle_length=2)
+    ds = ds.interleave(interleave_fn, cycle_length=2)
+    self.assertDatasetProduces(ds, [0, 0, 1, 1])
 
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeNonStringAddresses(self):
diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py
index 0d820d9..31220c6 100644
--- a/tensorflow/python/data/kernel_tests/options_test.py
+++ b/tensorflow/python/data/kernel_tests/options_test.py
@@ -18,6 +18,9 @@
 from __future__ import division
 from __future__ import print_function
 
+import platform
+import sys
+
 from absl.testing import parameterized
 
 from tensorflow.python.data.experimental.ops import optimization_options
@@ -65,6 +68,9 @@
 
   @combinations.generate(test_base.default_test_combinations())
   def testOptionsTwiceSameOption(self):
+    if sys.version_info >= (3, 8) and platform.system() == "Windows":
+      # TODO(b/165013260): Fix this
+      self.skipTest("Test is currently broken on Windows with Python 3.8")
     options1 = dataset_ops.Options()
     options1.experimental_optimization.autotune = False
     options2 = dataset_ops.Options()
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index 3159536..5da633a 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -52,10 +52,15 @@
 
 
 def v2_only_combinations():
-  """Returns the default test combinations for v1 only tf.data tests."""
+  """Returns the default test combinations for v2 only tf.data tests."""
   return combinations.combine(tf_api_version=2, mode=["eager", "graph"])
 
 
+def v2_eager_only_combinations():
+  """Returns the default test combinations for v2 eager only tf.data tests."""
+  return combinations.combine(tf_api_version=2, mode="eager")
+
+
 class DatasetTestBase(test.TestCase):
   """Base class for dataset tests."""
 
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index f6f2da0..479c8d3 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -36,7 +36,6 @@
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
 from tensorflow.python.training.saver import BaseSaverBuilder
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import deprecation
@@ -656,11 +655,7 @@
   in eager mode and inside of tf.functions.
   """
 
-  def __init__(self,
-               dataset=None,
-               components=None,
-               element_spec=None,
-               job_token=None):
+  def __init__(self, dataset=None, components=None, element_spec=None):
     """Creates a new iterator from the given dataset.
 
     If `dataset` is not specified, the iterator will be created from the given
@@ -673,20 +668,17 @@
       components: Tensor components to construct the iterator from.
       element_spec: A nested structure of `TypeSpec` objects that
         represents the type specification of elements of the iterator.
-      job_token: A token to use for reading from a tf.data service job. Data
-        will be partitioned among all iterators using the same token. If `None`,
-        the iterator will not read from the tf.data service.
 
     Raises:
       ValueError: If `dataset` is not provided and either `components` or
-        `element_spec` is not provided. Or `dataset` is provided and either
-        `components` and `element_spec` is provided.
+        `element_spec` is not provided, or if `dataset` is provided together
+        with `components` or `element_spec`.
     """
+    super(OwnedIterator, self).__init__()
     error_message = ("Either `dataset` or both `components` and "
                      "`element_spec` need to be provided.")
 
     self._device = context.context().device_name
-    self._job_token = job_token
 
     if dataset is None:
       if (components is None or element_spec is None):
@@ -729,11 +721,7 @@
           gen_dataset_ops.anonymous_iterator_v2(
               output_types=self._flat_output_types,
               output_shapes=self._flat_output_shapes))
-      if self._job_token is None:
-        gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
-      else:
-        gen_experimental_dataset_ops.make_data_service_iterator(
-            ds_variant, self._job_token, self._iterator_resource)
+      gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
       # Delete the resource when this object is deleted
       self._resource_deleter = IteratorResourceDeleter(
           handle=self._iterator_resource,
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 800e6a8..2b03b9d 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -13,6 +13,7 @@
 py_library(
     name = "distribute_test_lib_pip",
     deps = [
+        ":all_reduce",
         ":combinations",
         ":multi_worker_test_base",
         ":single_loss_example",
@@ -89,7 +90,6 @@
     srcs = ["cross_device_utils.py"],
     srcs_version = "PY2AND3",
     deps = [
-        ":all_reduce",
         ":values",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:collective_ops",
@@ -141,6 +141,7 @@
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":all_reduce",
         ":cross_device_ops",
         ":distribute_lib",
         ":mirrored_strategy",
@@ -870,6 +871,7 @@
     srcs = ["multi_worker_test_base.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":multi_process_runner",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:distributed_framework_test_lib",
@@ -879,12 +881,22 @@
         "//tensorflow/python:session",
         "//tensorflow/python:training_lib",
         "//tensorflow/python:util",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:remote",
         "//third_party/py/numpy",
     ],
 )
 
+tf_py_test(
+    name = "multi_worker_test_base_test",
+    srcs = ["multi_worker_test_base_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":multi_worker_test_base",
+    ],
+)
+
 cuda_py_test(
     name = "checkpoint_utils_test",
     size = "medium",
@@ -1781,16 +1793,15 @@
     srcs = ["parameter_server_strategy_v2.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":distribute_lib",
         ":parameter_server_strategy",
-        "//tensorflow/python:constant_op",
-        "//tensorflow/python:dtypes",
+        ":sharded_variable",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:util",
-        "//tensorflow/python:variables",
-        "//tensorflow/python/distribute:distribute_lib",
-        "//tensorflow/python/distribute:input_lib",
-        "//tensorflow/python/distribute:sharded_variable",
-        "//tensorflow/python/distribute:values",
+        "//tensorflow/python:partitioned_variables",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tf_decorator",
+        "//tensorflow/python:variable_scope",
+        "@absl_py//absl/logging",
     ],
 )
 
@@ -1801,9 +1812,17 @@
     deps = [
         ":multi_worker_test_base",
         ":parameter_server_strategy_v2",
+        ":sharded_variable",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:init_ops_v2",
+        "//tensorflow/python:linalg_ops_impl",
+        "//tensorflow/python:partitioned_variables",
         "//tensorflow/python:training_server_lib",
         "//tensorflow/python:variables",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:remote",
         "//tensorflow/python/eager:test",
     ],
 )
diff --git a/tensorflow/python/distribute/client/BUILD b/tensorflow/python/distribute/client/BUILD
index 9574f32..db4b421 100644
--- a/tensorflow/python/distribute/client/BUILD
+++ b/tensorflow/python/distribute/client/BUILD
@@ -67,18 +67,23 @@
     shard_count = 14,
     tags = ["no_oss"],  # TODO(b/162119374)
     deps = [
+        ":client",
         ":parameter_server_client",
+        "//tensorflow/python:check_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
-        "//tensorflow/python:init_ops_v2",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:random_ops",
+        "//tensorflow/python:tensor_spec",
         "//tensorflow/python:training_server_lib",
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute:multi_worker_test_base",
-        "//tensorflow/python/distribute:sharded_variable",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
+        "@absl_py//absl/logging",
     ],
 )
 
@@ -131,10 +136,6 @@
     name = "utils",
     srcs = ["utils.py"],
     srcs_version = "PY2AND3",
-    visibility = [
-        "//learning/tfx/users/apps_itemsuggest:__subpackages__",
-        "//tensorflow:internal",
-    ],
     deps = [
         "//tensorflow/python:training_server_lib",
     ],
diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py
index 90d50c3..7f3559c 100644
--- a/tensorflow/python/distribute/client/client.py
+++ b/tensorflow/python/distribute/client/client.py
@@ -26,11 +26,14 @@
 import enum
 import functools
 import os
+import re
 import sys
 import threading
 import weakref
 from absl import logging
 from six.moves import queue
+
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute.client import metric_utils
@@ -542,8 +545,9 @@
 class WorkerPreemptionHandler(object):
   """Handles worker preemptions."""
 
-  def __init__(self, server_def):
+  def __init__(self, server_def, cluster):
     self._server_def = server_def
+    self._cluster = cluster
     self._cluster_update_lock = threading.Lock()
     self._cluster_due_for_update = threading.Event()
     self._worker_up_cond = threading.Condition(self._cluster_update_lock)
@@ -577,6 +581,13 @@
     try:
       yield
     except errors.OpError as e:
+      # If the error is due to temporary connectivity issues between the
+      # worker and the PS, put the closure back, ignore the error, and do not
+      # mark the worker as failed.
+      if self._cluster._record_and_ignore_transient_ps_failure(e):  # pylint: disable=protected-access
+        if on_failure_fn:
+          on_failure_fn()
+        return
+
       self._validate_preemption_failure(e)
       logging.error("Worker %s failed with error: %s", worker_device_name, e)
       if on_failure_fn:
@@ -766,7 +777,6 @@
       device_filters.set_device_filters(
           "ps", i, ["/job:worker", "/job:%s" % client_name])
 
-    context.context().mirroring_policy = context.MIRRORING_ALL
     # Allow at most one outstanding RPC for each worker at a certain time. This
     # is to simplify worker failure handling in the runtime
     os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"
@@ -775,8 +785,25 @@
                               protocol=cluster_resolver.rpc_layer,
                               cluster_device_filters=device_filters)
 
+    # Ignore PS failures reported by workers due to transient connection errors.
+    # Transient connectivity issues between workers and PS are relayed by the
+    # workers to the client, leading the client to believe that there are PS
+    # failures. Transient and permanent PS failures are distinguished by the
+    # number of reports from the workers: when this env var is set to a
+    # positive integer K, the client ignores reports of a failed PS task until
+    # K closure executions have failed with errors from the same PS instance,
+    # at which point the PS instance is considered to have failed.
+    # TODO(b/164279603): Remove this workaround when the underlying connectivity
+    # issue in gRPC server is resolved.
+    self._transient_ps_failures_threshold = int(os.environ.get(
+        "TF_CLIENT_IGNORE_TRANSIENT_PS_FAILURES", 3))
+    self._potential_ps_failures_lock = threading.Lock()
+    self._potential_ps_failures_count = [0] * self._num_ps
+
     self._closure_queue = _CoordinatedClosureQueue()
-    self.failure_handler = WorkerPreemptionHandler(context.get_server_def())
+    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
+                                                   self)
     worker_device_strings = [
         "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
     ]
@@ -784,6 +811,22 @@
         Worker(i, w, self) for i, w in enumerate(worker_device_strings)
     ]
 
+  def _record_and_ignore_transient_ps_failure(self, e):
+    """Records potential PS failures and return if failure should be ignored."""
+    if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
+      return False
+
+    ps_tasks = _extract_failed_ps_instances(str(e))
+    with self._potential_ps_failures_lock:
+      for t in ps_tasks:
+        self._potential_ps_failures_count[t] += 1
+        # The number of `UnavailableError`s encountered on this PS task has
+        # reached the maximum number of ignored errors.
+        if (self._potential_ps_failures_count[t] >=
+            self._transient_ps_failures_threshold):
+          return False
+    return True
+
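To make the bookkeeping above easier to follow, here is a minimal standalone sketch of the same transient-failure counting; the class and function names are invented, and the `_is_ps_failure` pre-check from the diff is omitted.

```python
import os
import re
import threading


def extract_failed_ps_instances(err_msg):
  """Parses PS task ids such as '/job:ps/replica:0/task:3' from an error."""
  tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg)
  return set(int(t.split(":")[-1]) for t in tasks)


class TransientPSFailureCounter(object):
  """Ignores failure reports for a PS task until the threshold is reached."""

  def __init__(self, num_ps):
    self._threshold = int(
        os.environ.get("TF_CLIENT_IGNORE_TRANSIENT_PS_FAILURES", 3))
    self._lock = threading.Lock()
    self._counts = [0] * num_ps

  def record_and_should_ignore(self, error_message):
    """Returns True if the report should be treated as transient."""
    if self._threshold <= 0:
      return False
    with self._lock:
      for task in extract_failed_ps_instances(error_message):
        self._counts[task] += 1
        if self._counts[task] >= self._threshold:
          return False  # Too many reports: treat as a real PS failure.
    return True
```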
   def schedule(self, function, args, kwargs):
     """Schedules `function` to be dispatched to a worker for execution.
 
@@ -822,9 +865,7 @@
   """An object to schedule and orchestrate remote function execution.
 
   A `Client` object represents a program used to create datasets, schedule
-  functions to be executed, and fetch the results of the functions. Operations
-  that will involve other tasks in the cluster, such as variable creation,
-  reading variables etc., should be performed within `client.context()`.
+  functions to be executed, and fetch the results of the functions.
 
   Currently, `Client` is not supported for standalone use.
   It should be used in conjunction with `ParameterServerStrategyV2`. The
@@ -855,36 +896,9 @@
     self._strategy = strategy
     self.cluster = Cluster(strategy._cluster_resolver)
 
-  @contextlib.contextmanager
-  def context(self):
-    """Context manager under which client distribution is in effect.
-
-    All distribution related methods using this `Client`, including those that
-    create and update variables, should be used within this context. This
-    context manager handles cluster fault tolerance in remote function
-    execution.
-
-    The context manager calls `join` automatically when exiting successfully.
-
-    Entering `Client.context` also enters the underlying strategy's scope, and
-    this means that `tf.distribute.get_strategy()` will return the strategy
-    object being used.
-
-    Yields:
-      Nothing.
-    """
-    with self._strategy.scope(), self._handle_parameter_server_failure():
-      yield
-    self.join()
-
-  @contextlib.contextmanager
-  def experimental_variable_partitioning_scope(self):
-    with self._strategy.experimental_variable_partitioning_scope():
-      yield
-
-  (experimental_variable_partitioning_scope.__doc__) = (
-      parameter_server_strategy_v2.ParameterServerStrategyV2
-      .experimental_variable_partitioning_scope.__doc__)
+  @property
+  def strategy(self):
+    return self._strategy
 
   def schedule(self, fn, args=None, kwargs=None):
     """Schedules `fn` to be dispatched to a worker for execution asynchronously.
@@ -939,7 +953,9 @@
         scheduled function since the last time an error was thrown or since
         the beginning of the program.
     """
-    with self._translate_parameter_server_failure():
+    # Slot variables are usually created during function tracing time; thus
+    # `schedule` needs to be called within the `strategy.scope()`.
+    with self.strategy.scope(), _translate_parameter_server_failure():
       return self.cluster.schedule(fn, args=args, kwargs=kwargs)
 
   def join(self):
@@ -961,9 +977,7 @@
         scheduled function since the last time an error was thrown or since
         the beginning of the program.
     """
-    # TODO(b/159486639): Update the docs once we can cancel the functions being
-    # executed on workers, that when `join` returns, the system is stabilized.
-    with self._translate_parameter_server_failure():
+    with _translate_parameter_server_failure():
       self.cluster.join()
 
   def done(self):
@@ -1062,31 +1076,32 @@
       return (result,)
     return result
 
-  # pylint: disable=missing-function-docstring
-  @contextlib.contextmanager
-  def _translate_parameter_server_failure(self):
-    try:
-      yield
-    except Exception as e:  # pylint: disable=broad-except
-      if _is_ps_failure(e):
-        logging.exception("Encountered parameter server failures!")
-        raise ParameterServerFailureError(e)
-      else:
-        raise
 
-  # pylint: disable=missing-function-docstring
-  @contextlib.contextmanager
-  def _handle_parameter_server_failure(self):
-    try:
-      with self._translate_parameter_server_failure():
-        yield
-    except ParameterServerFailureError as e:  # pylint: disable=broad-except
-      restart_exit_code = os.environ.get(
-          "TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE", None)
-      if restart_exit_code is not None:
-        sys.exit(int(restart_exit_code))
-      else:
-        raise
+# pylint: disable=missing-function-docstring
+@contextlib.contextmanager
+def _translate_parameter_server_failure():
+  try:
+    yield
+  except Exception as e:  # pylint: disable=broad-except
+    if _is_ps_failure(e):
+      raise ParameterServerFailureError(e)
+    else:
+      raise
+
+
+# pylint: disable=missing-function-docstring
+@contextlib.contextmanager
+def handle_parameter_server_failure():
+  try:
+    with _translate_parameter_server_failure():
+      yield
+  except ParameterServerFailureError as e:  # pylint: disable=broad-except
+    restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE",
+                                       None)
+    if restart_exit_code is not None:
+      sys.exit(int(restart_exit_code))
+    else:
+      raise
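A hypothetical caller-side sketch of the now module-level `handle_parameter_server_failure` context manager; `run_training_loop` is a placeholder for user code.

```python
from tensorflow.python.distribute.client import client as client_lib


def run_training_loop():
  # Placeholder for a user-provided training loop.
  pass


# On a PS failure, this either re-raises ParameterServerFailureError or exits
# with TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE if that env var is set.
with client_lib.handle_parameter_server_failure():
  run_training_loop()
```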
 
 
 class _PerWorkerDistributedDataset(object):
@@ -1132,15 +1147,12 @@
     per_worker_iterator = self._client._create_per_worker_resources(
         _create_per_worker_iterator)
 
-    # Create an iterator, so the consumer function of this iterator can start
-    # tracing using this iterator without needing to wait for the completion of
-    # the iterater creation. Note: the iterator shouldn't use memory until it is
-    # consumed.
-    # TODO(b/154675763): get rid of this workaround once we can make input_fn a
-    # tf.function.
-    iterator = _create_per_worker_iterator()
+    # Set the type_spec of each RemoteValue so that functions taking these
+    # RemoteValues as inputs can be traced.
     for iterator_remote_value in per_worker_iterator._values:
-      iterator_remote_value._set_type_spec(iterator._type_spec)
+      iterator_remote_value._set_type_spec(
+          iterator_ops.IteratorSpec(
+              self._dataset_fn.structured_outputs.element_spec))
     return _PerWorkerDistributedIterator(per_worker_iterator._values)
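The type-spec wiring above can be shown in isolation. A minimal sketch, assuming TF 2.x eager mode, of deriving an `IteratorSpec` from a traced dataset function's element spec (the dataset itself is arbitrary):

```python
import tensorflow as tf
from tensorflow.python.data.ops import iterator_ops


@tf.function
def dataset_fn():
  return tf.data.Dataset.range(10)


# The concrete function's structured output is the dataset; its element_spec
# parameterizes the IteratorSpec used to trace consumers of the iterator.
concrete = dataset_fn.get_concrete_function()
spec = iterator_ops.IteratorSpec(concrete.structured_outputs.element_spec)
```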
 
   @property
@@ -1162,6 +1174,12 @@
                               "is not supported right now.")
 
 
+def _extract_failed_ps_instances(err_msg):
+  """Return a set of potentially failing ps instances from error message."""
+  tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg)
+  return set(int(t.split(":")[-1]) for t in tasks)
+
+
 def _is_ps_failure(error):
   """Whether the error is considered a parameter server failure."""
   if (_RPC_ERROR_FROM_PS in str(error) or
diff --git a/tensorflow/python/distribute/client/client_test.py b/tensorflow/python/distribute/client/client_test.py
index 9698d6c..3ea3e46 100644
--- a/tensorflow/python/distribute/client/client_test.py
+++ b/tensorflow/python/distribute/client/client_test.py
@@ -19,6 +19,8 @@
 from __future__ import print_function
 
 import collections
+import platform
+import sys
 import threading
 import time
 from absl import logging
@@ -137,6 +139,10 @@
     coord.join([t])
 
   def testWaitRaiseErrorAfterMarkFailure(self):
+    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
+      # TODO(b/165013260): Fix this
+      self.skipTest('Test is currently broken on Windows with Python 3.8')
+
     closure_queue = client._CoordinatedClosureQueue()
     closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
     closure = closure_queue.get()
@@ -183,6 +189,10 @@
     return closure_queue, closure1, closure2
 
   def testPutRaiseError(self):
+    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
+      # TODO(b/165013260): Fix this
+      self.skipTest('Test is currently broken on Windows with Python 3.8')
+
     closure_queue, _, closure2 = self._put_two_closures_and_get_one()
 
     closure_queue.mark_failed(ValueError())
@@ -202,6 +212,10 @@
     closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
 
   def testWaitRaiseError(self):
+    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
+      # TODO(b/165013260): Fix this
+      self.skipTest('Test is currently broken on Windows with Python 3.8')
+
     closure_queue, _, closure2 = self._put_two_closures_and_get_one()
 
     closure_queue.mark_failed(ValueError())
@@ -220,6 +234,10 @@
     closure_queue.wait()
 
   def testDoneRaiseError(self):
+    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
+      # TODO(b/165013260): Fix this
+      self.skipTest('Test is currently broken on Windows with Python 3.8')
+
     closure_queue, _, _ = self._put_two_closures_and_get_one()
 
     self.assertFalse(closure_queue.done())
@@ -236,6 +254,10 @@
       closure_queue.mark_failed(e)
 
   def _test_cancel_closure_when_error(self, call_wait):
+    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
+      # TODO(b/165013260): Fix this
+      self.skipTest('Test is currently broken on Windows with Python 3.8')
+
     closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
     closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
     closure_queue.get()
@@ -306,6 +328,10 @@
     self._test_cancel_closure_when_error(call_wait=False)
 
   def testStateIsRestoredAfterJoinIsCalled(self):
+    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
+      # TODO(b/165013260): Fix this
+      self.skipTest('Test is currently broken on Windows with Python 3.8')
+
     closure_queue, _, _ = self._put_two_closures_and_get_one()
     self.assertEqual(closure_queue._inflight_closure_count, 1)
     closure_queue.mark_failed(ValueError('test error'))
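The Windows + Python 3.8 guard repeated across these tests could in principle be centralized; a hypothetical helper (not part of this change) might look like:

```python
import platform
import sys


def skip_if_broken_on_windows_py38(test_case):
  """Skips `test_case` on Windows under Python >= 3.8 (b/165013260)."""
  if sys.version_info >= (3, 8) and platform.system() == 'Windows':
    test_case.skipTest('Test is currently broken on Windows with Python 3.8')
```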
diff --git a/tensorflow/python/distribute/client/parameter_server_client.py b/tensorflow/python/distribute/client/parameter_server_client.py
index 8236c24..de8e45a 100644
--- a/tensorflow/python/distribute/client/parameter_server_client.py
+++ b/tensorflow/python/distribute/client/parameter_server_client.py
@@ -49,7 +49,7 @@
 
   """
 
-  def __init__(self, cluster_resolver):
+  def __init__(self, cluster_resolver, variable_partitioner=None):
     super(ParameterServerClient, self).__init__(
         parameter_server_strategy_v2.ParameterServerStrategyV2(
-            cluster_resolver))
+            cluster_resolver, variable_partitioner))
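A hypothetical construction using the new `variable_partitioner` argument; `resolver` is assumed to be an already-configured cluster resolver, and the fixed-size partitioner is only one possible choice.

```python
from tensorflow.python.distribute.client import parameter_server_client
from tensorflow.python.ops import partitioned_variables

# Shard each eligible variable across two parameter servers.
client = parameter_server_client.ParameterServerClient(
    resolver,
    variable_partitioner=partitioned_variables.fixed_size_partitioner(2))
```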
diff --git a/tensorflow/python/distribute/client/parameter_server_client_test.py b/tensorflow/python/distribute/client/parameter_server_client_test.py
index 5edf7ba..a438345 100644
--- a/tensorflow/python/distribute/client/parameter_server_client_test.py
+++ b/tensorflow/python/distribute/client/parameter_server_client_test.py
@@ -26,7 +26,6 @@
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import multi_worker_test_base
-from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.distribute.client import client
 from tensorflow.python.distribute.client import parameter_server_client
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
@@ -37,7 +36,6 @@
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import init_ops_v2
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
@@ -107,7 +105,7 @@
 
   def testBasic(self):
     self.client._strategy.extended._variable_count = 0
-    with self.client.context():
+    with self.client.strategy.scope():
       v1 = variables.Variable(initial_value=0.0)
       v2 = variables.Variable(initial_value=1.0)
     self.assertEqual(self.client._strategy.extended._variable_count, 2)
@@ -141,7 +139,7 @@
     def input_fn():
       return dataset_ops.DatasetV2.range(1, 2)
 
-    with self.client.context():
+    with self.client.strategy.scope():
       v = variables.Variable(initial_value=0, dtype=dtypes.int64)
 
     @def_function.function
@@ -165,7 +163,7 @@
     def input_fn():
       return dataset_ops.DatasetV2.from_tensor_slices([2] * 10)
 
-    with self.client.context():
+    with self.client.strategy.scope():
       v = variables.Variable(initial_value=0, dtype=dtypes.int32)
 
     # TODO(yuefengz): the following tf.function has a return value which is None
@@ -232,89 +230,6 @@
     with self.assertRaises(ValueError):
       self.client.create_per_worker_dataset(input_fn)
 
-
-class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest):
-  """Test basic functionality works with explicit maximum closure queue size.
-
-  Execute the same set of test cases as in ParameterServerClientTest, with an
-  explicit size limit for the closure queue. Note that even when the queue size
-  is set to infinite, there is still a maximum practical size (depends on host
-  memory limit) that might cause the queue.put operations to be blocking when
-  scheduling a large number of closures on a big cluster. These tests make sure
-  that the client does not run into deadlocks in such scenario.
-  """
-
-  @classmethod
-  def setUpClass(cls):
-    super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
-    client._CLOSURE_QUEUE_MAX_SIZE = 2
-    cls.client = make_client(num_workers=3, num_ps=2)
-
-
-class VariablePartitioningScopeTest(test.TestCase):
-
-  @classmethod
-  def setUpClass(cls):
-    super(VariablePartitioningScopeTest, cls).setUpClass()
-    cls.client = make_client(num_workers=3, num_ps=2)
-
-  def testBasic(self):
-    with self.client.context():
-      with self.client.experimental_variable_partitioning_scope():
-        init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
-        v1 = variables.Variable(
-            initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
-            shape=(5, 2),
-            dtype=dtypes.int64)
-
-        init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
-        v2 = variables.Variable(
-            initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
-            shape=(6, 1),
-            dtype=dtypes.int64)
-
-    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
-    self.assertLen(v1.variables, 2)
-    self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
-    self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
-    self.assertAllEqual(v1.variables[0].read_value().numpy(),
-                        [[0, 1], [2, 3], [4, 5]])
-    self.assertAllEqual(v1.variables[1].read_value().numpy(), [[6, 7], [8, 9]])
-
-    self.assertIsInstance(v2, sharded_variable.ShardedVariable)
-    self.assertLen(v2.variables, 2)
-    self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
-    self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
-    self.assertAllEqual(v2.variables[0].read_value().numpy(), [[0], [1], [2]])
-    self.assertAllEqual(v2.variables[1].read_value().numpy(), [[3], [4], [5]])
-
-  def testSurplusPS(self):
-    with self.client.context():
-      with self.client.experimental_variable_partitioning_scope():
-        initializer = init_ops_v2.Constant([0])
-
-        v = variables.Variable(
-            initial_value=lambda: initializer(shape=(1,), dtype=dtypes.int64),
-            shape=(1,),
-            dtype=dtypes.int64)
-
-    self.assertIsInstance(v, sharded_variable.ShardedVariable)
-    self.assertLen(v.variables, 1)
-    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
-    self.assertAllEqual(v.variables[0].read_value().numpy(), [0])
-
-  def testInvalidArgument(self):
-    with self.assertRaisesRegex(ValueError, "initial_value"):
-      with self.client.experimental_variable_partitioning_scope():
-        variables.Variable(initial_value=[0, 1, 2], shape=(3,))
-
-    with self.assertRaisesRegex(ValueError, "shape"):
-      with self.client.experimental_variable_partitioning_scope():
-        initializer = init_ops_v2.Constant([0, 1, 2])
-        variables.Variable(
-            initial_value=lambda: initializer(shape=(3,), dtype=dtypes.int64),
-            dtype=dtypes.int64)
-
   def testPerWorkerValue(self):
     var_shape = tuple()
     var_dtype = dtypes.float32
@@ -350,6 +265,24 @@
     self.assertEqual(var_sum, 10.0)
 
 
+class LimitedClosureQueueSizeBasicTest(ParameterServerClientTest):
+  """Test basic functionality works with explicit maximum closure queue size.
+
+  Execute the same set of test cases as in ParameterServerClientTest, with an
+  explicit size limit for the closure queue. Note that even when the queue size
+  is set to infinite, there is still a maximum practical size (depends on host
+  memory limit) that might cause the queue.put operations to be blocking when
+  scheduling a large number of closures on a big cluster. These tests make sure
+  that the client does not run into deadlocks in such scenario.
+  """
+
+  @classmethod
+  def setUpClass(cls):
+    super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
+    client._CLOSURE_QUEUE_MAX_SIZE = 2
+    cls.client = make_client(num_workers=3, num_ps=2)
+
+
 class ErrorReportingTest(TestCaseWithErrorReportingThread):
 
   @classmethod
@@ -357,7 +290,7 @@
     super(ErrorReportingTest, cls).setUpClass()
     cls.client = make_client(num_workers=3, num_ps=2)
 
-    with cls.client.context():
+    with cls.client.strategy.scope():
       cls.iteration = variables.Variable(initial_value=0.0)
 
   @def_function.function
@@ -476,7 +409,7 @@
     client._CLOSURE_QUEUE_MAX_SIZE = 2
     cls.client = make_client(num_workers=3, num_ps=2)
 
-    with cls.client.context():
+    with cls.client.strategy.scope():
       cls.iteration = variables.Variable(initial_value=0.0)
 
 
diff --git a/tensorflow/python/distribute/client/utils.py b/tensorflow/python/distribute/client/utils.py
index 6c59557..51d8263 100644
--- a/tensorflow/python/distribute/client/utils.py
+++ b/tensorflow/python/distribute/client/utils.py
@@ -28,8 +28,11 @@
   """Start a server and block the process from exiting."""
   # This function is for multi-processing tests or users who would like to have
   # every job run the same binary for simplicity.
-  assert (cluster_resolver.task_type == 'worker' or
-          cluster_resolver.task_type == 'ps')
+  if not (cluster_resolver.task_type == 'worker' or
+          cluster_resolver.task_type == 'ps'):
+    raise ValueError('Unexpected task_type to start a server: {}'.format(
+        cluster_resolver.task_type))
+
   server = server_lib.Server(
       cluster_resolver.cluster_spec().as_cluster_def(),
       job_name=cluster_resolver.task_type,
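One reason to prefer the explicit `ValueError` over the previous `assert`: assertions are stripped under `python -O`, so the check would silently disappear. A minimal sketch of the same validation in isolation:

```python
def validate_task_type(task_type):
  # Raising (rather than asserting) survives `python -O` and produces a
  # clearer message for misconfigured jobs.
  if task_type not in ('worker', 'ps'):
    raise ValueError(
        'Unexpected task_type to start a server: {}'.format(task_type))
```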
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index 6a133c7..49b6a93 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -25,6 +25,7 @@
 
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.core.protobuf import tensorflow_server_pb2
+from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
@@ -39,7 +40,6 @@
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -188,6 +188,9 @@
   _check_health_interval = 30
   # Timeout in seconds for the first health check. The first health check
   # needs to wait for the cluster to be ready, which may take a longer time.
+  #
+  # TODO(b/151232436): the initial barrier may currently hang in a rare case,
+  # so we need a finite timeout.
   _check_health_initial_timeout = 1200
 
   def __init__(self,
@@ -629,55 +632,20 @@
         destinations=destinations,
         experimental_hints=experimental_hints)
 
-  def _check_health(self, device, group_key, instance_key):
-    first = True
-    # We need to use a large enough value so that the all-reduce forms a
-    # complete RING. In RING implementation, when value is too small, the
-    # all-reduce may degrade into broadcasts. This means that some worker
-    # failure may not be detected.
-    value = array_ops.ones((32, 32), dtype=dtypes.float32)
+  def _check_health(self):
     while True:
       if self._check_health_thread_should_stop.is_set():
         return
-      timeout = None
-      if first:
-        # For the first check health we set timeout since it may need to do
-        # group resolution, which may hang if the cluster is never healthy.
-        timeout = self._check_health_initial_timeout
-        first = False
       try:
-        # We use an dummy all-reduce as a way to check the health of a cluster.
-        # For RING it should be able to detect failed workers in the cluster if
-        # the values are large enough.
-        #
-        # We're not using CrossDeviceOps because we need to run it with
-        # pre-allocated group and instance keys.
-        #
-        # TODO(b/151232436): Replace the reduce with a check health op once we
-        # add that.
-        with ops.device(device):
-          collective_ops.all_reduce(
-              value,
-              group_size=self._num_workers,
-              group_key=group_key,
-              instance_key=instance_key,
-              merge_op="Add",
-              final_op="Id",
-              subdiv_offsets=[0],
-              communication_hint="ring",
-              timeout=timeout)
-          if context.is_async():
-            context.async_wait()
-      except (errors.UnavailableError, errors.DeadlineExceededError,
-              errors.FailedPreconditionError, errors.CancelledError) as e:
+        for job in self._cluster_spec.jobs:
+          for task_id in range(self._cluster_spec.num_tasks(job)):
+            context.context().check_collective_ops_peer_health(
+                "/job:{}/replica:0/task:{}".format(job, task_id))
+      except (errors.UnavailableError, errors.FailedPreconditionError) as e:
         # TODO(b/151232436): Always raise UnavailableError when a peer fails.
         # Now there could be many kinds of errors:
         # - Unavailable: when the peer is not reachable, e.g. it's down.
         # - FailedPrecondition: when the peer has restarted.
-        # - DeadlineExceeded: when the first check health exceeds the deadline,
-        #   e.g. the peers take too long to be ready.
-        # - Cancelled: when failures in organic collectives aborts first,
-        #   outgoing RPCs may be aborted with Cancelled.
         logging.error("Cluster check alive failed, aborting collectives")
         context.context().abort_collective_ops(
             errors.UNAVAILABLE, "cluster check alive failed: %s" % e)
@@ -689,20 +657,32 @@
       time.sleep(self._check_health_interval)
 
   def _start_check_health_thread(self):
-    # Allocate group and instance key before starting the thread to avoid
-    # indeterminism. There can only be one thread that assigns group keys and
-    # instance keys, otherwise different workers may end up with unmatched keys
-    # since execution order between threads are arbitrary.
-    device = device_util.canonicalize(self._worker_device)
-    group_key = self._collective_keys.get_group_key([device])
-    instance_key = self._collective_keys.get_op_instance_key()
+    # Use a dummy all-reduce as a barrier to wait for all workers to be up,
+    # otherwise the health check may fail immediately.
+    #
+    # TODO(b/151232436): change to an explicit barrier if we have it.
+    dummy_value = ops.convert_to_tensor([])
+    logging.info("Waiting for the cluster, timeout = %d",
+                 self._check_health_initial_timeout)
+    try:
+      self._host_cross_device_ops.reduce(
+          reduce_util.ReduceOp.SUM,
+          dummy_value,
+          dummy_value,
+          experimental_hints=collective_util.Hints(
+              timeout_seconds=self._check_health_initial_timeout))
+      if context.is_async():
+        context.async_wait()
+    except errors.DeadlineExceededError:
+      raise RuntimeError(
+          "Timeout waiting for the cluster, timeout is %d seconds" %
+          self._check_health_initial_timeout)
     self._check_health_thread_should_stop = threading.Event()
     # Start the thread as a daemon to avoid it blocking the program from
     # exiting. We try our best to shut down the thread, but __del__ is not
     # guaranteed to be called when the program exits.
     self._check_health_thread = threading.Thread(
         target=self._check_health,
-        args=(device, group_key, instance_key),
         daemon=True)
     self._check_health_thread.start()
 
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index ed3e2d5..9c554bc 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -81,7 +81,7 @@
     reduce_op, value, destinations, num_replicas_in_graph):
   """Reduce a non-DistributedValue `value` to `destinations`."""
   if isinstance(value, value_lib.DistributedValues):
-    raise ValueError("You are passing a `DistributedValue` to "
+    raise ValueError("You are passing a `DistributedValues` to "
                      "`reduce_non_distributed_value`, which is not allowed.")
 
   # If the same value is present on all replicas then the PerReplica value will
@@ -216,7 +216,18 @@
 
 @tf_export("distribute.CrossDeviceOps")
 class CrossDeviceOps(object):
-  """Base class for cross-device reduction and broadcasting algorithms."""
+  """Base class for cross-device reduction and broadcasting algorithms.
+
+  The main purpose of this class is to be passed to
+  `tf.distribute.MirroredStrategy` in order to choose among different
+  cross-device communication implementations. Prefer using the methods of
+  `tf.distribute.Strategy` instead of the ones of this class.
+
+  Implementations:
+  * `tf.distribute.ReductionToOneDevice`
+  * `tf.distribute.NcclAllReduce`
+  * `tf.distribute.HierarchicalCopyAllReduce`
+  """
 
   def __init__(self):
     pass
@@ -233,24 +244,30 @@
              experimental_hints=None):
     """Reduce `per_replica_value` to `destinations`.
 
-    It runs the reduction operation defined by `reduce_op` and put the
-    result on `destinations`.
+    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
+    the cross-replica context.
 
     Args:
-      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
-        per_replica_value will be reduced.
-      per_replica_value: A `tf.distribute.DistributedValues` object or a tensor
-        with device set.
-      destinations: the reduction destinations.
-      experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
-        to perform collective operations.
+      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
+        combined.
+      per_replica_value: a `tf.distribute.DistributedValues`, or a
+        `tf.Tensor`-like object.
+      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
+        `tf.Tensor`-like object, or a device string. It specifies the devices
+        to reduce to. To perform an all-reduce, pass the same value to both
+        `per_replica_value` and `destinations`. Note that if it's a
+        `tf.Variable`, the value is reduced to the devices of that variable,
+        and this method doesn't update the variable.
+      experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
+        `tf.distribute.experimental.CollectiveHints` for details.
 
     Returns:
-      a Mirrored object.
+      A `tf.Tensor` or `tf.distribute.DistributedValues`.
 
     Raises:
-      ValueError: if per_replica_value can't be converted to a PerReplica
-        object or if destinations aren't strings, Variables or DistributedValues
+      ValueError: if per_replica_value can't be converted to a
+        `tf.distribute.DistributedValues` or if destinations is not a string,
+        `tf.Variable` or `tf.distribute.DistributedValues`.
     """
     if not isinstance(per_replica_value, value_lib.DistributedValues):
       per_replica_value = _make_tensor_into_per_replica(per_replica_value)
@@ -274,28 +291,26 @@
                    reduce_op,
                    value_destination_pairs,
                    experimental_hints=None):
-    """Reduce PerReplica objects in a batch.
+    """Reduce values to destinations in batches.
 
-    Reduce each first element in `value_destination_pairs` to each second
-    element which indicates the destinations.
-
-    This can be faster than multiple individual `reduce`s because we can
-    fuse several tensors into one or multiple packs before reduction.
+    See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
+    called in the cross-replica context.
 
     Args:
-      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the
-        `per_replica_value` will be reduced.
-      value_destination_pairs: A list or a tuple of PerReplica objects (or
-        tensors with device set if there is one device) and destinations.
-      experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
-        to perform collective operations.
+      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
+        combined.
+      value_destination_pairs: a sequence of (value, destinations) pairs. See
+        `tf.distribute.CrossDeviceOps.reduce` for descriptions.
+      experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
+        `tf.distribute.experimental.CollectiveHints` for details.
 
     Returns:
-      a list of Mirrored objects.
+      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
+      in `value_destination_pairs`.
 
     Raises:
       ValueError: if `value_destination_pairs` is not an iterable of
-        tuples of PerReplica objects and destinations.
+        tuples of `tf.distribute.DistributedValues` and destinations.
     """
     # TODO(yuefengz): if destinations are different, split into several
     # `_batch_reduce` invocations.
@@ -323,14 +338,20 @@
                                             experimental_hints)
 
   def broadcast(self, tensor, destinations):
-    """Broadcast the `tensor` to destinations.
+    """Broadcast `tensor` to `destinations`.
+
+    This can only be called in the cross-replica context.
 
     Args:
-      tensor: the tensor to broadcast.
-      destinations: the broadcast destinations.
+      tensor: a `tf.Tensor`-like object. The value to broadcast.
+      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
+        `tf.Tensor`-like object, or a device string. It specifies the devices
+        to broadcast to. Note that if it's a `tf.Variable`, the value is
+        broadcast to the devices of that variable, and this method doesn't
+        update the variable.
 
     Returns:
-      a Mirrored object.
+      A `tf.Tensor` or `tf.distribute.DistributedValues`.
     """
     validate_destinations(destinations)
     return self.broadcast_implementation(tensor, destinations)
@@ -338,27 +359,31 @@
   @doc_controls.for_subclass_implementers
   def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                             experimental_hints):
-    """The implementation of reduce of `per_replica_value` to `destinations`.
+    """Implementation of `reduce`.
 
     Overriding this method is useful for subclass implementers.
 
-    It runs the reduction operation defined by `reduce_op` and put the
-    result on `destinations`.
-
     Args:
-      reduce_op: An instance `tf.distribute.ReduceOp` that indicates of how
-        per_replica_value will be reduced.
-      per_replica_value: A PerReplica object or a tensor with device set.
-      destinations: the reduction destinations.
-      experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
-        to perform collective operations.
+      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
+        combined.
+      per_replica_value: a `tf.distribute.DistributedValues`, or a
+        `tf.Tensor`-like object.
+      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
+        `tf.Tensor`-like object, or a device string. It specifies the devices
+        to reduce to. To perform an all-reduce, pass the same value to both
+        `per_replica_value` and `destinations`. Note that if it's a
+        `tf.Variable`, the value is reduced to the devices of that variable,
+        and this method doesn't update the variable.
+      experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
+        `tf.distribute.experimental.CollectiveHints` for details.
 
     Returns:
-      a Mirrored object.
+      A `tf.Tensor` or `tf.distribute.DistributedValues`.
 
     Raises:
-      ValueError: if per_replica_value can't be converted to a PerReplica
-        object.
+      ValueError: if per_replica_value can't be converted to a
+        `tf.distribute.DistributedValues` or if destinations is not a string,
+        `tf.Variable` or `tf.distribute.DistributedValues`.
     """
     raise NotImplementedError(
         "_reduce method must be implemented in descendants.")
@@ -366,27 +391,25 @@
   @doc_controls.for_subclass_implementers
   def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                   experimental_hints):
-    """Implementation of reduce PerReplica objects in a batch.
+    """Implementation of `batch_reduce`.
 
     Overriding this method is useful for subclass implementers.
 
-    Reduce each first element in `value_destination_pairs` to each second
-    element which indicates the destinations.
-
     Args:
-      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
-        per_replica_value will be reduced.
-      value_destination_pairs: An iterable of tuples of PerReplica objects
-        (or tensors with device set if there is one device) and destinations.
-      experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
+      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
+        combined.
+      value_destination_pairs: a sequence of (value, destinations) pairs. See
+        `reduce` for descriptions.
+      experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
         to perform collective operations.
 
     Returns:
-      a list of Mirrored objects.
+      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
+      in `value_destination_pairs`.
 
     Raises:
       ValueError: if `value_destination_pairs` is not an iterable of
-        tuples of PerReplica objects and destinations
+        tuples of `tf.distribute.DistributedValues` and destinations.
     """
     raise NotImplementedError(
         "batch_reduce_implementation method must be implemented in descendants."
@@ -394,26 +417,36 @@
 
   @doc_controls.for_subclass_implementers
   def broadcast_implementation(self, tensor, destinations):
-    """Implementation of broadcast the `tensor` to destinations.
+    """Implementation of `broadcast`.
 
     Args:
-      tensor: the tensor to broadcast.
-      destinations: the broadcast destinations.
+      tensor: a `tf.Tensor`-like object. The value to broadcast.
+      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
+        `tf.Tensor`-like object, or a device string. It specifies the devices
+        to broadcast to. Note that if it's a `tf.Variable`, the value is
+        broadcast to the devices of that variable, and this method doesn't
+        update the variable.
 
     Returns:
-      a Mirrored object.
+      A `tf.Tensor` or `tf.distribute.DistributedValues`.
     """
     return simple_broadcast(tensor, destinations, always_mirrored=True)
 
 
 @tf_export("distribute.ReductionToOneDevice")
 class ReductionToOneDevice(CrossDeviceOps):
-  """Always do reduction to one device first and then do broadcasting.
+  """A CrossDeviceOps implementation that copies values to one device to reduce.
 
-  Batch reduction is done by reduction on each element one by one.
+  This implementation always copies values to one device to reduce them, then
+  broadcasts the reduced values to the destinations. It doesn't support
+  efficient batching.
+
+  Here is how you can use `ReductionToOneDevice` in
+  `tf.distribute.MirroredStrategy`:
 
   ```
-    mirrored_strategy = tf.distribute.MirroredStrategy(
+    strategy = tf.distribute.MirroredStrategy(
       cross_device_ops=tf.distribute.ReductionToOneDevice())
   ```
   """
@@ -423,8 +456,8 @@
 
     Args:
       reduce_to_device: the intermediate device to reduce to. If None, reduce
-        to the first device in `destinations` of the `reduce()` method.
-      accumulation_fn: a function that does accumulation.  If None, then
+        to the first device in `destinations` of the `reduce` method.
+      accumulation_fn: a function that does accumulation.  If None,
         `tf.math.add_n` is used.
     """
     self.reduce_to_device = reduce_to_device
@@ -641,18 +674,24 @@
 
 
 class AllReduceCrossDeviceOps(CrossDeviceOps):
-  """Reduction using all-reduce."""
+  """All-reduce implementation of CrossDeviceOps.
+
+  It performs all-reduce when applicable using NCCL or hierarchical copy. For
+  the batch API, tensors will be repacked or aggregated for more efficient
+  cross-device transportation.
+
+  For reduces that are not all-reduce, it falls back to
+  `tf.distribute.ReductionToOneDevice`.
+  """
 
   def __init__(self, all_reduce_alg="nccl", num_packs=1):
-    """All-reduce implementation of CrossDeviceOps.
-
-    Before performing all-reduce, tensors will be packed for more efficient
-    cross-device transportation.
+    """Initializes the object.
 
     Args:
       all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
         "hierarchical_copy" are supported.
-      num_packs: If non-zero, pack values into `num_packs` splits.
+      num_packs: a non-negative integer. The number of packs to split values
+        into. If zero, no packing will be done.
     """
     self._all_reduce_alg = all_reduce_alg
     self._num_packs = num_packs
@@ -746,21 +785,32 @@
 
 @tf_export("distribute.NcclAllReduce")
 class NcclAllReduce(AllReduceCrossDeviceOps):
-  """Reduction using NCCL all-reduce."""
+  """NCCL all-reduce implementation of CrossDeviceOps.
+
+  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
+  repacked or aggregated for more efficient cross-device transportation.
+
+  For reduces that are not all-reduce, it falls back to
+  `tf.distribute.ReductionToOneDevice`.
+
+  Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:
+
+
+  ```
+    strategy = tf.distribute.MirroredStrategy(
+      cross_device_ops=tf.distribute.NcclAllReduce())
+  ```
+  """
 
   def __init__(self, num_packs=1):
-    """NCCL all-reduce implementation of CrossDeviceOps.
-
-    It uses Nvidia NCCL for all-reduce. Before performing all-reduce, tensors
-    will be repacked or aggregated for more efficient cross-device
-    transportation.
+    """Initializes the object.
 
     Args:
-      num_packs: values will be packed in this many splits.  `num_packs` should
-        be greater than or equals 0. When it is zero, no packing will be done.
+      num_packs: a non-negative integer. The number of packs to split values
+        into. If zero, no packing will be done.
 
     Raises:
-      ValueError if `num_packs` is negative.
+      ValueError: if `num_packs` is negative.
     """
     if num_packs < 0:
       raise ValueError(
@@ -772,23 +822,34 @@
 
 @tf_export("distribute.HierarchicalCopyAllReduce")
 class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
-  """Reduction using hierarchical copy all-reduce.
+  """Hierarchical copy all-reduce implementation of CrossDeviceOps.
 
   It reduces to one GPU along edges in some hierarchy and broadcasts back to
-  each GPU along the same path. Before performing all-reduce, tensors will be
-  repacked or aggregated for more efficient cross-device transportation.
+  each GPU along the same path. For the batch API, tensors will be repacked or
+  aggregated for more efficient cross-device transportation.
 
   This is a reduction created for the Nvidia DGX-1, which assumes GPUs are
   connected as on a DGX-1 machine. If you have different GPU inter-connections,
   it is likely to be slower than `tf.distribute.ReductionToOneDevice`.
+
+  For reduces that are not all-reduce, it falls back to
+  `tf.distribute.ReductionToOneDevice`.
+
+  Here is how you can use `HierarchicalCopyAllReduce` in
+  `tf.distribute.MirroredStrategy`:
+
+  ```
+    strategy = tf.distribute.MirroredStrategy(
+      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
+  ```
   """
 
   def __init__(self, num_packs=1):
     """Initializes the object.
 
     Args:
-      num_packs: values will be packed in this many splits.  `num_packs` should
-        be greater than or equals 0. When it is zero, no packing will be done.
+      num_packs: a non-negative integer. The number of packs to split values
+        into. If zero, no packing will be done.
 
     Raises:
       ValueError: if `num_packs` is negative.
@@ -802,117 +863,6 @@
         num_packs=num_packs)
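To show an implementation being selected as the updated docstrings recommend, here is a small sketch exercised through the public `tf.distribute.Strategy` API rather than `CrossDeviceOps` directly; device availability is assumed, and with a single device the reduction is trivial but the call shape is identical.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.ReductionToOneDevice())


def replica_fn():
  return tf.constant(1.0)


# `strategy.run` produces per-replica values; `strategy.reduce` combines them
# across devices using the configured cross-device ops.
per_replica = strategy.run(replica_fn)
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
```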
 
 
-class MultiWorkerAllReduce(AllReduceCrossDeviceOps):
-  """All-reduce algorithms for distributed TensorFlow."""
-
-  def __init__(self,
-               worker_devices,
-               num_gpus_per_worker,
-               all_reduce_spec=("pscpu/pscpu", 2, -1),
-               num_packs=0):
-    """Initialize the all-reduce algorithm.
-
-    Args:
-      worker_devices: a list of device strings for workers participating in
-        all-reduce.
-      num_gpus_per_worker: number of GPU devices per worker.
-      all_reduce_spec: a tuple or a named tuple or a list of tuples specifying
-        the all-reduce algorithm.
-        1. The first element of a tuple is the name of the all-reduce algorithm.
-        Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd",
-        "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with
-        a "/" are hierarchical, so two all-reduces are executed, the first one
-        aggregates tensors within a worker and the second aggregates across
-        workers.
-        2. The second element of a tuple is the number of shards when doing
-        all-reduce. Let's say its values is M, each tensor after packing will be
-        split into M shards and then M parallel all-reduces would be performed
-        before finally they are concatenated backed into a complete tensor.
-        3. The third element is the maximum size of tensors that will be
-        applicable for the algorithm specified by the first element. For
-        example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)],
-        tensors with size not larger than 1024 bytes will be applied a 2-shard
-        "nccl" all-reduce and other tensors will be applied a 2-shard
-        "pscpu/pscpu" algorithm. The third elements should be in increasing
-        order across tuples and end with -1 which indicates infinity.
-      num_packs: see AllReduceCrossDeviceOps.
-    """
-    self._worker_devices = worker_devices
-    self._num_gpus_per_worker = num_gpus_per_worker
-    super(MultiWorkerAllReduce, self).__init__(num_packs=num_packs)
-
-    def validate_and_complete_spec(spec):
-      """Validate and complete the all-reduce spec."""
-      # TODO(yuefengz): support namedtuple.
-      if not isinstance(spec, tuple):
-        raise ValueError(
-            "A tuple is expected for all-reduce spec: %r" % all_reduce_spec)
-      if not spec or len(spec) > 3:
-        raise ValueError(
-            "Too many elements in the all-reduce spec tuple: %r" % spec)
-      if len(spec) == 1:
-        return AllReduceSpecTuple(spec[0], 1, -1)
-      elif len(spec) == 2:
-        return AllReduceSpecTuple(spec[0], spec[1], -1)
-      else:
-        return AllReduceSpecTuple(*spec)
-
-    self._all_reduce_spec = []
-    if isinstance(all_reduce_spec, six.string_types):
-      self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1))
-    elif isinstance(all_reduce_spec, tuple):
-      self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec))
-    elif isinstance(all_reduce_spec, list):
-      self._all_reduce_spec = [
-          validate_and_complete_spec(spec) for spec in all_reduce_spec
-      ]
-
-  def _batch_all_reduce(self, reduce_op, per_replica_values):
-    """All-reduce algorithm in a batch."""
-    logging.log_first_n(
-        logging.INFO, "Distributed batch_all_reduce: %d all-reduces with "
-        "allreduce_spec = %r, num_packs = %d" %
-        (len(per_replica_values), self._all_reduce_spec, self._num_packs), 10)
-
-    device_grads = _group_value_by_device(per_replica_values)
-
-    # The all-reduce library requires fully defined shapes.
-    # TODO(yuefengz): when tensor sharding is not needed, static shapes are not
-    # required as well.
-    for device_grad in device_grads:
-      for grad, _ in device_grad:
-        if not grad.shape.is_fully_defined():
-          raise ValueError("Shape is unknown for node %r" % grad)
-
-    remaining_grads = device_grads
-    aggregated_grads = []
-    for spec_tuple in self._all_reduce_spec:
-      if spec_tuple.limit < 0:
-        this_grads = remaining_grads
-        remaining_grads = []
-      else:
-        (this_grads, remaining_grads) = cross_device_utils.split_grads_by_size(
-            spec_tuple.limit, remaining_grads)
-      if this_grads:
-        device_grad_packs, tensor_packer = _pack_tensors(
-            this_grads, self._num_packs)
-        range_agg_grads = cross_device_utils.sum_gradients_all_reduce(
-            self._worker_devices, device_grad_packs, len(self._worker_devices),
-            spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
-        range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)
-
-        if not aggregated_grads:
-          aggregated_grads = range_agg_grads
-        else:
-          assert len(aggregated_grads) == len(range_agg_grads)
-          for i, range_agg_grad in enumerate(range_agg_grads):
-            aggregated_grads[i] += range_agg_grad
-    assert not remaining_grads
-
-    return _ungroup_and_make_mirrored(aggregated_grads, per_replica_values[0],
-                                      reduce_op)
-
-
 @tf_export("distribute.experimental.CollectiveCommunication")
 class CollectiveCommunication(enum.Enum):
   """Communication choices for CollectiveOps.
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 967de7d..557c601 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -433,55 +433,6 @@
     self.assertAllEqual(self.evaluate(result.values), [1.0, 1.0])
 
 
-class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
-                                    CrossDeviceOpsTestBase):
-
-  worker_devices = [
-      "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
-  ]
-  multi_worker_allreduce_combinations = combinations.combine(
-      cross_device_ops=[
-          combinations.NamedObject(
-              "MultiWorkerAllReduce",
-              cross_device_ops_lib.MultiWorkerAllReduce(worker_devices, 2,
-                                                        ("pscpu/pscpu", 2, -1),
-                                                        0)),
-          combinations.NamedObject(
-              "MultiWorkerAllReducePack",
-              cross_device_ops_lib.MultiWorkerAllReduce(worker_devices, 2,
-                                                        ("pscpu/pscpu", 2, -1),
-                                                        1)),
-          combinations.NamedObject(
-              "MultiWorkerAllReduceMultipleSpecs",
-              cross_device_ops_lib.MultiWorkerAllReduce(
-                  worker_devices, 2, [("pscpu/pscpu", 2, 100),
-                                      ("xring", 2, -1)], 0)),
-      ],
-      devices=[
-          [
-              "/job:worker/replica:0/task:0/device:CPU:0",
-              "/job:worker/replica:0/task:1/device:CPU:0"
-          ],
-          [
-              "/job:worker/replica:0/task:0/device:GPU:0",
-              "/job:worker/replica:0/task:1/device:GPU:0"
-          ],
-          [
-              "/job:worker/replica:0/task:0/device:GPU:0",
-              "/job:worker/replica:0/task:0/device:GPU:1",
-              "/job:worker/replica:0/task:1/device:GPU:0",
-              "/job:worker/replica:0/task:1/device:GPU:1"
-          ],
-      ],
-      mode=["graph"])
-
-  @combinations.generate(multi_worker_allreduce_combinations)
-  def testReductionAndBroadcast(self, cross_device_ops, devices):
-    # Mimic the default device of multi-worker strategies.
-    with ops.device("/job:worker/replica:0/task:0"):
-      self._testReductionAndBroadcast(cross_device_ops, devices)
-
-
 NUM_WORKERS = 3
 
 CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication
diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py
index a8d4d17..bf7b413 100644
--- a/tensorflow/python/distribute/cross_device_utils.py
+++ b/tensorflow/python/distribute/cross_device_utils.py
@@ -18,16 +18,13 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections as pycoll
 import copy
 import threading
 
-from tensorflow.python.distribute import all_reduce
 from tensorflow.python.distribute import values as value_lib
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import device as pydev
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
@@ -171,65 +168,6 @@
     return (grad, v), None
 
 
-def group_device_names(devices, group_size):
-  """Group device names into groups of group_size.
-
-  Args:
-    devices: a list of canonical device strings.
-    group_size: integer which is equal to or greater than 1.
-
-  Returns:
-    list of lists of devices, where each inner list is group_size long,
-      and each device appears at least once in an inner list.  If
-      len(devices) % group_size == 0 then each device will appear exactly once.
-
-  Raises:
-    ValueError: if group_size > len(devices)
-  """
-  num_devices = len(devices)
-  if group_size > num_devices:
-    raise ValueError(
-        'only %d devices, but group_size=%d' % (num_devices, group_size))
-  num_groups = (
-      num_devices // group_size + (1 if (num_devices % group_size != 0) else 0))
-  groups = [[] for i in range(num_groups)]
-  for i in range(num_groups * group_size):
-    groups[i % num_groups].append(devices[i % num_devices])
-  return groups
-
-
-def split_grads_by_size(threshold_size, device_grads):
-  """Break gradients into two sets according to tensor size.
-
-  Args:
-    threshold_size: int size cutoff for small vs large tensor.
-    device_grads: List of lists of (gradient, variable) tuples.  The outer
-        list is over devices. The inner list is over individual gradients.
-
-  Returns:
-    small_grads: Subset of device_grads where shape is <= threshold_size
-       elements.
-    large_grads: Subset of device_grads where shape is > threshold_size
-       elements.
-  """
-  small_grads = []
-  large_grads = []
-  for dl in device_grads:
-    small_dl = []
-    large_dl = []
-    for (g, v) in dl:
-      tensor_size = g.get_shape().num_elements()
-      if tensor_size <= threshold_size:
-        small_dl.append([g, v])
-      else:
-        large_dl.append([g, v])
-    if small_dl:
-      small_grads.append(small_dl)
-    if large_dl:
-      large_grads.append(large_dl)
-  return small_grads, large_grads
-
-
 # TODO(yuefengz): use random key starts to avoid reusing keys?
 class CollectiveKeys(object):
   """Class that manages collective keys.
@@ -580,272 +518,6 @@
   return out_slices_list
 
 
-def sum_grad_and_var_all_reduce(grad_and_vars,
-                                num_workers,
-                                alg,
-                                gpu_indices,
-                                aux_devices=None,
-                                num_shards=1):
-  """Apply all-reduce algorithm over specified gradient tensors."""
-  with ops.name_scope('allreduce'):
-    # Note that each grad_and_vars looks like the following:
-    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
-    scaled_grads = [g for g, _ in grad_and_vars]
-    if alg == 'nccl':
-      summed_grads = nccl_ops.all_sum(scaled_grads)
-    elif alg == 'xring':
-      summed_grads = all_reduce.build_ring_all_reduce(
-          scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add)
-    elif alg == 'nccl/xring':
-      summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards,
-                                                     math_ops.add)
-    elif alg == 'nccl/rechd':
-      summed_grads = all_reduce.build_nccl_then_recursive_hd(
-          scaled_grads, math_ops.add)
-    elif alg == 'nccl/pscpu':
-      summed_grads = all_reduce.build_nccl_then_shuffle(
-          scaled_grads, aux_devices, math_ops.add, math_ops.add_n)
-    elif alg == 'pscpu/pscpu':
-      second_gather_devices = aux_devices[:num_shards]
-      summed_grads = all_reduce.build_shuffle_then_shuffle(
-          scaled_grads, aux_devices, second_gather_devices, math_ops.add_n)
-    elif alg in ['pscpu', 'psgpu']:
-      summed_grads = all_reduce.build_shuffle_all_reduce(
-          scaled_grads, aux_devices, math_ops.add_n)
-    else:
-      raise ValueError('unsupported all_reduce alg: ', alg)
-
-  result = []
-  for (_, v), g in zip(grad_and_vars, summed_grads):
-    result.append([g, v])
-  return result
-
-
-def sum_gradients_all_reduce(dev_prefixes, replica_grads, num_workers, alg,
-                             num_shards, gpu_indices):
-  """Apply all-reduce algorithm over specified gradient tensors.
-
-  Args:
-    dev_prefixes: list of prefix strings to use to generate PS device names.
-    replica_grads: the gradients to reduce.
-    num_workers: number of worker processes across entire job.
-    alg: the all-reduce algorithm to apply.
-    num_shards: alg-specific sharding factor.
-    gpu_indices: indices of local GPUs in order usable for ring-reduce.
-
-  Returns:
-    list of reduced tensors
-  """
-  alg_contains_shuffle = any(n in alg for n in ['pscpu', 'psgpu'])
-  is_hierarchical = '/' in alg
-  if 'pscpu' in alg:
-    aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
-  elif 'psgpu' in alg:
-    aux_devices = [
-        prefix + '/gpu:%d' % i
-        for i in range(len(gpu_indices))
-        for prefix in dev_prefixes
-    ]
-  else:
-    aux_devices = ['/job:localhost/cpu:0']
-  # Auxiliary devices for hierarchical all-reduces.
-  aux_device_groups = group_device_names(
-      aux_devices, num_shards if alg_contains_shuffle else 1)
-  group_index = 0
-  reduced_gv_list = []
-  for grad_and_vars in zip(*replica_grads):
-    reduced_gv_list.append(
-        sum_grad_and_var_all_reduce(
-            grad_and_vars, num_workers, alg, gpu_indices, aux_devices
-            if is_hierarchical else aux_device_groups[group_index], num_shards))
-    group_index = (group_index + 1) % len(aux_device_groups)
-  new_replica_grads = [list(x) for x in zip(*reduced_gv_list)]
-  return new_replica_grads
-
-
-def extract_ranges(index_list, range_size_limit=32):
-  """Extract consecutive ranges and singles from index_list.
-
-  Args:
-    index_list: List of monotone increasing non-negative integers.
-    range_size_limit: Largest size range to return.  If a larger
-      consecutive range exists, it will be returned as multiple
-      ranges.
-
-  Returns:
-    (ranges, singles) where ranges is a list of [first, last] pairs of
-      consecutive elements in index_list, and singles is all of the
-      other elements, in original order.
-  """
-  if not index_list:
-    return [], []
-  first = index_list[0]
-  last = first
-  ranges = []
-  singles = []
-  for i in index_list[1:]:
-    if i == last + 1 and (last - first) <= range_size_limit:
-      last = i
-    else:
-      if last > first:
-        ranges.append([first, last])
-      else:
-        singles.append(first)
-      first = i
-      last = i
-  if last > first:
-    ranges.append([first, last])
-  else:
-    singles.append(first)
-  return ranges, singles
-
-
-GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes')
-
-
-def pack_range(key, packing, grad_vars, rng):
-  """Form the concatenation of a specified range of gradient tensors.
-
-  Args:
-    key: Value under which to store meta-data in packing that will be used
-      later to restore the grad_var list structure.
-    packing: Dict holding data describing packed ranges of small tensors.
-    grad_vars: List of (grad, var) pairs for one replica.
-    rng: A pair of integers giving the first, last indices of a consecutive
-      range of tensors to be packed.
-
-  Returns:
-    A tensor that is the concatenation of all the specified small tensors.
-  """
-  to_pack = grad_vars[rng[0]:rng[1] + 1]
-  members = []
-  variables = []
-  restore_shapes = []
-  with ops.name_scope('pack'):
-    for g, v in to_pack:
-      variables.append(v)
-      restore_shapes.append(g.shape)
-      with ops.device(g.device):
-        members.append(array_ops.reshape(g, [-1]))
-    packing[key] = GradPackTuple(
-        indices=range(rng[0], rng[1] + 1),
-        vars=variables,
-        shapes=restore_shapes)
-    with ops.device(members[0].device):
-      return array_ops.concat(members, 0)
-
-
-def unpack_grad_tuple(gv, gpt):
-  """Unpack a previously packed collection of gradient tensors.
-
-  Args:
-    gv: A (grad, var) pair to be unpacked.
-    gpt: A GradPackTuple describing the packing operation that produced gv.
-
-  Returns:
-    A list of (grad, var) pairs corresponding to the values that were
-     originally packed into gv, maybe following subsequent operations like
-     reduction.
-  """
-  elt_widths = [x.num_elements() for x in gpt.shapes]
-  with ops.device(gv[0].device):
-    with ops.name_scope('unpack'):
-      splits = array_ops.split(gv[0], elt_widths)
-      unpacked_gv = []
-      for idx, s in enumerate(splits):
-        unpacked_gv.append((array_ops.reshape(s, gpt.shapes[idx]),
-                            gpt.vars[idx]))
-  return unpacked_gv
-
-
-def pack_small_tensors(replica_grads, max_bytes=0, max_group=0):
-  """Concatenate small gradient tensors together for reduction.
-
-  Args:
-    replica_grads: List of lists of (gradient, variable) tuples.
-    max_bytes: Int giving max number of bytes in a tensor that
-      may be considered small.
-    max_group: Int giving max number of small tensors that may be
-      concatenated into one new tensor.
-
-  Returns:
-    new_replica_grads, packing where new_replica_grads is identical to
-      replica_grads except that all feasible small_tensors have been removed
-      from their places and concatenated into larger tensors that are
-      now in the front of the list for each replica, and packing contains
-      the data necessary to restore the replica_grads structure.
-
-  Look through the first replica for gradients of the same type (float),
-  and small size, that are all sequential.  For each such group,
-  replace by a new tensor that is a flattened concatenation.  Note
-  that the corresponding variable will be absent, which doesn't matter
-  because it isn't used during all-reduce.
-
-  Requires:
-    Every gv_list in replicas must have isomorphic structure including identical
-      tensor sizes and types.
-  """
-  small_indices = []
-  large_indices = []
-  for idx, (g, _) in enumerate(replica_grads[0]):
-    if g.dtype == dtypes.float32 and (4 * g.shape.num_elements()) <= max_bytes:
-      small_indices.append(idx)
-    else:
-      large_indices.append(idx)
-  small_ranges, small_singles = extract_ranges(
-      small_indices, range_size_limit=max_group)
-  large_indices = sorted(large_indices + small_singles)
-  num_gv = len(replica_grads[0])
-  packing = {}
-  if small_ranges:
-    new_replica_grads = []
-    for dev_idx, gv_list in enumerate(replica_grads):
-      assert len(gv_list) == num_gv
-      new_gv_list = []
-      for r in small_ranges:
-        key = '%d:%d' % (dev_idx, len(new_gv_list))
-        new_gv_list.append((pack_range(key, packing, gv_list, r),
-                            'packing_var_placeholder'))
-      for i in large_indices:
-        new_gv_list.append(gv_list[i])
-      new_replica_grads.append(new_gv_list)
-    return new_replica_grads, packing
-  else:
-    return replica_grads, None
-
-
-def unpack_small_tensors(replica_grads, packing):
-  """Undo the structure alterations to replica_grads done by pack_small_tensors.
-
-  Args:
-    replica_grads: List of List of (grad, var) tuples.
-    packing: A dict generated by pack_small_tensors describing the changes
-      it made to replica_grads.
-
-  Returns:
-    new_replica_grads: identical to replica_grads except that concatenations
-      of small tensors have been split apart and returned to their original
-      positions, paired with their original variables.
-  """
-  if not packing:
-    return replica_grads
-  new_replica_grads = []
-  num_devices = len(replica_grads)
-  num_packed = len(packing.keys()) // num_devices
-  for dev_idx, gv_list in enumerate(replica_grads):
-    gv_list = list(gv_list)
-    new_gv_list = gv_list[num_packed:]
-    for i in range(num_packed):
-      k = '%d:%d' % (dev_idx, i)
-      gpt = packing[k]
-      gv = unpack_grad_tuple(gv_list[i], gpt)
-      for gi, idx in enumerate(gpt.indices):
-        assert idx == gpt.indices[gi]
-        new_gv_list.insert(idx, gv[gi])
-    new_replica_grads.append(new_gv_list)
-  return new_replica_grads
-
-
 def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
   """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
   if any(isinstance(v, ops.IndexedSlices) for v in values):
@@ -875,18 +547,6 @@
   return result
 
 
-def contains_indexed_slices(value):
-  """Check whether the value is `IndexedSlices` or contains `IndexedSlices`."""
-  if isinstance(value, ops.IndexedSlices):
-    return True
-  elif isinstance(value, (list, tuple)) and value:
-    return any(contains_indexed_slices(v) for v in value)
-  elif isinstance(value, value_lib.DistributedValues):
-    return contains_indexed_slices(value.values)
-  else:
-    return False
-
-
 def is_indexed_slices(value):
   if isinstance(value, ops.IndexedSlices):
     return True
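
A minimal sketch (an illustration, not part of this change) of where
`IndexedSlices` arise in practice, assuming TF2 eager mode: the gradient of
`tf.gather` with respect to a variable is an `IndexedSlices`, so the surviving
`is_indexed_slices` check reduces to an isinstance test on the leaf value.

    import tensorflow as tf

    params = tf.Variable([[1.], [2.], [3.]])
    with tf.GradientTape() as tape:
      y = tf.gather(params, [0, 2])
    grad = tape.gradient(y, params)
    print(isinstance(grad, tf.IndexedSlices))  # True
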
diff --git a/tensorflow/python/distribute/cross_device_utils_test.py b/tensorflow/python/distribute/cross_device_utils_test.py
index 9781bf6..626ec5c 100644
--- a/tensorflow/python/distribute/cross_device_utils_test.py
+++ b/tensorflow/python/distribute/cross_device_utils_test.py
@@ -81,32 +81,7 @@
   def testIsIndexedSlices(self):
     t = math_ops._as_indexed_slices(
         constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
-    self.assertTrue(cross_device_utils.contains_indexed_slices(t))
-
-  @test_util.run_in_graph_and_eager_modes
-  def testContainsIndexedSlices_List(self):
-    t0 = math_ops._as_indexed_slices(
-        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
-    t1 = math_ops._as_indexed_slices(
-        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
-    self.assertTrue(cross_device_utils.contains_indexed_slices([t0, t1]))
-
-  @test_util.run_in_graph_and_eager_modes
-  def testContainsIndexedSlices_Tuple(self):
-    t0 = math_ops._as_indexed_slices(
-        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
-    t1 = math_ops._as_indexed_slices(
-        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
-    self.assertTrue(cross_device_utils.contains_indexed_slices((t0, t1)))
-
-  @test_util.run_in_graph_and_eager_modes
-  def testContainsIndexedSlices_PerReplica(self):
-    t0 = math_ops._as_indexed_slices(
-        constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
-    t1 = math_ops._as_indexed_slices(
-        constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
-    per_replica = value_lib.PerReplica((t0, t1))
-    self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica))
+    self.assertTrue(cross_device_utils.is_indexed_slices(t))
 
   @combinations.generate(combinations.combine(
       mode=["graph", "eager"],
diff --git a/tensorflow/python/distribute/custom_training_loop_input_test.py b/tensorflow/python/distribute/custom_training_loop_input_test.py
index 3103d73..a835f5e 100644
--- a/tensorflow/python/distribute/custom_training_loop_input_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_input_test.py
@@ -636,6 +636,34 @@
       combinations.combine(
           distribution=strategy_combinations.multidevice_strategies,
           mode=["eager"]))
+  def testSegmentSumWithDynamicNumberOfSegments(self, distribution):
+
+    def dataset_fn(_):
+      data = array_ops.zeros(5, dtype=dtypes.int32)
+      dataset = get_dataset_from_tensor_slices(data)
+      dataset = dataset.batch(3)
+      return dataset
+
+    input_iterator = iter(
+        distribution.experimental_distribute_datasets_from_function(dataset_fn))
+
+    @def_function.function
+    def step_fn(example):
+      segment_ids = array_ops.zeros_like_v2(example)
+      num_segment = array_ops.shape(example)[0]
+      # If the number of segments is dynamic, the output shape is dynamic.
+      return math_ops.unsorted_segment_sum(example, segment_ids, num_segment)
+
+    # This assumes that there are exactly 2 replicas
+    outputs = distribution.experimental_local_results(
+        distribution.run(step_fn, args=(next(input_iterator),)))
+    self.assertAllEqual((3,), outputs[0].shape)
+    self.assertAllEqual((2,), outputs[1].shape)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.multidevice_strategies,
+          mode=["eager"]))
   def testReshapeWithDynamicInputs(self, distribution):
 
     def dataset_fn(_):
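
A standalone sketch of the pattern the new test exercises, assuming TF2 with
`tf.function`: when `num_segments` comes from `tf.shape`, the output of
`tf.math.unsorted_segment_sum` has a dynamic leading dimension.

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])
    def segment_sum(example):
      segment_ids = tf.zeros_like(example)
      num_segments = tf.shape(example)[0]  # Dynamic: depends on batch size.
      return tf.math.unsorted_segment_sum(example, segment_ids, num_segments)

    print(segment_sum(tf.constant([1, 2, 3], tf.int32)))  # [6 0 0]
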
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 173caa3..bd24ec0 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -2332,16 +2332,18 @@
     <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
 
     Args:
-      reduce_op: a `tf.distribute.ReduceOp` or string. How to reduce the value.
-      value: a `tf.distribute.DistributedValue`, or a `tf.Tensor` like object.
-      destinations: a `tf.distribute.DistributedValue`, a `tf.Variable`, a
+      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
+        be combined. Also accepts the string representation of the enum, such
+        as "SUM" or "MEAN".
+      value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object.
+      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
         `tf.Tensor` alike object, or a device string. It specifies the devices
         to reduce to. To perform an all-reduce, pass the same to `value` and
         `destinations`. Note that if it's a `tf.Variable`, the value is reduced
-        to the devices of that variable, this method doesn't update the variable.
-      experimental_hints: a `tf.distrbute.experimental.CollectiveHints`. Hints
-        to perform collective operations. See
-        `tf.distrbute.experimental.CollectiveHints` for details.
+        to the devices of that variable, and this method doesn't update the
+        variable.
+      experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
+        `tf.distribute.experimental.CollectiveHints` for details.
 
     Returns:
       A tensor or value reduced to `destinations`.
@@ -2413,11 +2415,13 @@
     <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
 
     Args:
-      reduce_op: a `tf.distribute.ReduceOp`. How to reduce the value.
+      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
+        be combined. Also accepts the string representation of the enum, such
+        as "SUM" or "MEAN".
       value_destination_pairs: a sequence of (value, destinations) pairs. See
-        `reduce_to()` for descriptions.
-      experimental_hints: a `tf.distrbute.experimental.CollectiveHints`. Hints
-        to perform collective operations.
+        `tf.distribute.Strategy.reduce_to` for descriptions.
+      experimental_hints: a `tf.distribute.experimental.CollectiveHints`. See
+        `tf.distribute.experimental.CollectiveHints` for details.
 
     Returns:
       A list of reduced values, one per pair in `value_destination_pairs`.
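
A minimal sketch of the string form documented above, assuming a two-replica
MirroredStrategy; `reduce_to` runs in the cross-replica context and accepts
"SUM" in place of the enum value.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    per_replica = strategy.run(tf.identity, args=(1.,))
    total = strategy.extended.reduce_to("SUM", per_replica,
                                        destinations="CPU:0")
    print(total)  # tf.Tensor(2.0, shape=(), dtype=float32)
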
@@ -3010,32 +3014,64 @@
     return (device_util.current(),)
 
   def all_reduce(self, reduce_op, value, experimental_hints=None):
-    """All-reduces the given `value Tensor` nest across replicas.
+    """All-reduces `value` across all replicas.
 
-    If `all_reduce` is called in any replica, it must be called in all replicas.
-    The nested structure and `Tensor` shapes must be identical in all replicas.
+    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+    >>> def step_fn():
+    ...   ctx = tf.distribute.get_replica_context()
+    ...   value = tf.identity(1.)
+    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
+    >>> strategy.experimental_local_results(strategy.run(step_fn))
+    (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
+     <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
 
-    IMPORTANT: The ordering of communications must be identical in all replicas.
+    It supports batched operations. You can pass a list of values and it
+    attempts to batch them when possible. You can also specify
+    `experimental_hints` to indicate the desired batching behavior, e.g. to
+    batch the values into multiple packs so that they can better overlap with
+    computations.
 
-    Example with two replicas:
-      Replica 0 `value`: {'a': 1, 'b': [40, 1]}
-      Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}
+    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+    >>> def step_fn():
+    ...   ctx = tf.distribute.get_replica_context()
+    ...   value1 = tf.identity(1.)
+    ...   value2 = tf.identity(2.)
+    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
+    >>> strategy.experimental_local_results(strategy.run(step_fn))
+    ([PerReplica:{
+      0: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
+      1: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>
+    }, PerReplica:{
+      0: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>,
+      1: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
+    }],)
 
-      If `reduce_op` == `SUM`:
-        Result (on all replicas): {'a': 4, 'b': [42, 99]}
+    Note that all replicas need to participate in the all-reduce; otherwise
+    this operation hangs. If there are multiple all-reduces, they need to
+    execute in the same order on all replicas. Dispatching all-reduce based on
+    conditions is usually error-prone.
 
-      If `reduce_op` == `MEAN`:
-        Result (on all replicas): {'a': 2, 'b': [21, 49.5]}
+    This API currently can only be called in the replica context. Other
+    variants to reduce values across replicas are:
+    * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
+      in the cross-replica context.
+    * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
+      all-reduce API in the cross-replica context.
+    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
+      to the host in cross-replica context.
 
     Args:
-      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
-      value: The nested structure of `Tensor`s to all-reduce. The structure must
-        be compatible with `tf.nest`.
-      experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
+      reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
+        be combined. Also accepts the string representation of the enum, such
+        as "SUM" or "MEAN".
+      value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts.
+        The structure and the shapes of the `tf.Tensor` need to be same on all
+        replicas.
+      experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
         to perform collective operations.
 
     Returns:
-       A `Tensor` nest with the reduced `value`s from each replica.
+       A nested structure of `tf.Tensor` with the reduced values. The structure
+       is the same as `value`.
     """
     if isinstance(reduce_op, six.string_types):
       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
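
The replica-context variant applies the same string-to-enum conversion, as the
code above shows; a short sketch, assuming two mirrored replicas, combining
the string form with the batched call described in the new docstring:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def step_fn():
      ctx = tf.distribute.get_replica_context()
      # "mean" is uppercased and converted to tf.distribute.ReduceOp.MEAN.
      return ctx.all_reduce("mean", [tf.identity(2.), tf.identity(4.)])

    strategy.experimental_local_results(strategy.run(step_fn))
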
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index b77739c..d689346 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -499,21 +499,27 @@
     return InputWorkers(worker_device_pairs)
 
 
-def _get_next_as_optional(iterator, strategy, name=None):
-  """Returns an empty dataset indicator and the next input from the iterator."""
+def _get_next_as_optional(iterator, strategy, return_per_replica=False):
+  """Returns an empty dataset indicator and the next input from the iterator.
+
+  Args:
+    iterator: a DistributedIterator object.
+    strategy: the `tf.distribute.Strategy` instance.
+    return_per_replica: a boolean. If True, the returned data will be wrapped
+      in a `PerReplica` structure. Otherwise it is a 2D list of size
+      num_input_workers * num_replicas_per_worker.
+
+  Returns:
+    A tuple (a boolean tensor indicating whether the next batch has a value
+    globally, data from all replicas).
+  """
   replicas = []
   worker_has_values = []
   worker_devices = []
   for i, worker in enumerate(iterator._input_workers.worker_devices):  # pylint: disable=protected-access
-    if name is not None:
-      d = tf_device.DeviceSpec.from_string(worker)
-      new_name = "%s_%s_%d" % (name, d.job, d.task)
-    else:
-      new_name = None
-
     with ops.device(worker):
       worker_has_value, next_element = (
-          iterator._iterators[i].get_next_as_list(new_name))  # pylint: disable=protected-access
+          iterator._iterators[i].get_next_as_list())  # pylint: disable=protected-access
       # Collective all-reduce requires explicit devices for inputs.
       with ops.device("/cpu:0"):
         # Converting to integers for all-reduce.
@@ -523,6 +529,12 @@
       # Make `replicas` a flat list of values across all replicas.
       replicas.append(next_element)
 
+  if return_per_replica:
+    flattened_data = []
+    for per_worker_data in replicas:
+      flattened_data.extend(per_worker_data)
+    replicas = distribute_utils.regroup(flattened_data)
+
   # Run an all-reduce to see whether any worker has values.
   # TODO(b/131423105): we should be able to short-cut the all-reduce in some
   # cases.
@@ -622,29 +634,15 @@
     return self
 
   def get_next_as_optional(self):
-    global_has_value, replicas = _get_next_as_optional(self, self._strategy)
+    global_has_value, replicas = _get_next_as_optional(
+        self, self._strategy, return_per_replica=True)
 
     def return_none():
       return optional_ops.Optional.empty(self._element_spec)
 
-    def return_value(replicas):
-      """Wraps the inputs for replicas in an `tf.experimental.Optional`."""
-      results = []
-      for i, worker in enumerate(self._input_workers.worker_devices):
-        with ops.device(worker):
-          devices = self._input_workers.compute_devices_for_worker(i)
-          for j, device in enumerate(devices):
-            with ops.device(device):
-              result = replicas[i][j]
-              results.append(result)
-      replicas = results
-
-      return optional_ops.Optional.from_value(
-          distribute_utils.regroup(replicas))
-
-    return control_flow_ops.cond(global_has_value,
-                                 lambda: return_value(replicas),
-                                 lambda: return_none())  # pylint: disable=unnecessary-lambda
+    return control_flow_ops.cond(
+        global_has_value, lambda: optional_ops.Optional.from_value(replicas),
+        return_none)
 
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
@@ -671,7 +669,8 @@
       out_of_range_replicas.append(data)
       return data
 
-    global_has_value, replicas = _get_next_as_optional(self, self._strategy)
+    global_has_value, replicas = _get_next_as_optional(
+        self, self._strategy, return_per_replica=False)
     results = []
     for i, worker in enumerate(self._input_workers.worker_devices):
       with ops.device(worker):
@@ -906,7 +905,8 @@
   def reduce(self, initial_state, reduce_fn):
     """Execute a `reduce_fn` over all the elements of the input."""
     iterator = iter(self)
-    has_data, data = _get_next_as_optional(iterator, self._strategy)
+    has_data, data = _get_next_as_optional(
+        iterator, self._strategy, return_per_replica=True)
 
     def cond(has_data, data, state):
       del data, state  # Unused.
@@ -915,16 +915,9 @@
     def loop_body(has_data, data, state):
       """Executes `reduce_fn` in a loop till the dataset is empty."""
       del has_data  # Unused.
-      # data is list of lists here. where each list corresponds to one worker.
-      # TODO(b/130570614): Add support for the multiworker and TPU pods use
-      # case.
-      if self._input_workers.num_workers == 1:
-        data = data[0]
-      else:
-        raise ValueError("Dataset iteration within a tf.function is"
-                         " not supported for multiple workers.")
-      state = reduce_fn(state, distribute_utils.regroup(data))
-      has_data, data = _get_next_as_optional(iterator, self._strategy)
+      state = reduce_fn(state, data)
+      has_data, data = _get_next_as_optional(
+          iterator, self._strategy, return_per_replica=True)
       return has_data, data, state
 
     has_data, data, final_state = control_flow_ops.while_loop(
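
A consumer-side sketch of the `get_next_as_optional` path simplified above,
assuming a single-device MirroredStrategy; the returned
`tf.experimental.Optional` now wraps the regrouped per-replica value directly.

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(["CPU:0"])
    iterator = iter(strategy.experimental_distribute_dataset(
        tf.data.Dataset.range(4).batch(2)))

    optional = iterator.get_next_as_optional()
    while optional.has_value():
      print(optional.get_value())
      optional = iterator.get_next_as_optional()
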
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index a70eb50..ea41154 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -992,8 +992,7 @@
 
     strategy = mirrored_strategy.MirroredStrategy(
         devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
-        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
-            ["/job:worker/task:0", "/job:worker/task:1"], 1))
+        cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
     worker_devices = self._cpu_devices()
     with context.graph_mode(), self.cached_session() as sess:
       if auto_shard_policy == AutoShardPolicy.AUTO:
@@ -1022,8 +1021,7 @@
 
     strategy = mirrored_strategy.MirroredStrategy(
         devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
-        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
-            ["/job:worker/task:0", "/job:worker/task:1"], 1))
+        cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
     worker_devices = self._cpu_devices()
     with context.graph_mode(), strategy.scope(), self.cached_session() as sess:
 
@@ -1064,8 +1062,7 @@
     strategy = mirrored_strategy.MirroredStrategy(
         devices=(self._cpu_and_one_gpu_devices()[0][1] +
                  self._cpu_and_one_gpu_devices()[1][1]),
-        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
-            ["/job:worker/task:0", "/job:worker/task:1"], 2))
+        cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
     worker_devices = self._cpu_and_one_gpu_devices()
     with context.graph_mode(), strategy.scope(), self.cached_session() as sess:
 
@@ -1097,8 +1094,7 @@
                        enable_get_next_as_optional):
     strategy = mirrored_strategy.MirroredStrategy(
         devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
-        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
-            ["/job:worker/task:0", "/job:worker/task:1"], 1))
+        cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
     worker_devices = self._cpu_devices()
 
     def dataset_fn(ctx):
@@ -1144,8 +1140,7 @@
     strategy = mirrored_strategy.MirroredStrategy(
         devices=(self._cpu_and_one_gpu_devices()[0][1] +
                  self._cpu_and_one_gpu_devices()[1][1]),
-        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
-            ["/job:worker/task:0", "/job:worker/task:1"], 2))
+        cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
     if tf2.enabled():
       dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2)
     else:
@@ -1263,8 +1258,7 @@
     strategy = mirrored_strategy.MirroredStrategy(
         devices=(self._cpu_and_one_gpu_devices()[0][1] +
                  self._cpu_and_one_gpu_devices()[1][1]),
-        cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
-            ["/job:worker/task:0", "/job:worker/task:1"], 2))
+        cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
     worker_devices = self._cpu_and_one_gpu_devices()
     with context.graph_mode(), strategy.scope(), self.cached_session() as sess:
 
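
These tests now construct the strategy with the public `ReductionToOneDevice`
cross-device ops; a hedged sketch of the equivalent standalone construction:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(
        devices=["CPU:0"],
        cross_device_ops=tf.distribute.ReductionToOneDevice())
    with strategy.scope():
      v = tf.Variable(1.0)  # Reductions for v are computed on one device.
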
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 07798dc..79a5636 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -86,7 +86,7 @@
   for task_type in ("chief", "worker"):
     for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
       if num_gpus_per_worker == 0:
-        devices.append("/job:%s/task:%d" % (task_type, task_id))
+        devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id))
       else:
         devices.extend([
             "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
@@ -378,8 +378,10 @@
     self._is_multi_worker_training = True
 
     if len(workers) > 1:
-      if not isinstance(self._cross_device_ops,
-                        cross_device_ops_lib.MultiWorkerAllReduce):
+      # Grandfather usage in the legacy tests if they're configured properly.
+      if (not isinstance(self._cross_device_ops,
+                         cross_device_ops_lib.ReductionToOneDevice) or
+          self._cross_device_ops._num_between_graph_workers > 1):  # pylint: disable=protected-access
         raise ValueError(
             "In-graph multi-worker training with `MirroredStrategy` is not "
             "supported.")
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index 5c86cbe..acdfdbb 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -1148,9 +1148,9 @@
                 # pylint: disable=g-long-lambda
                 lambda: mirrored_strategy.MirroredStrategy(
                     devices=mirrored_strategy.all_local_devices(),
-                    cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce([
-                        "/job:worker/task:0", "/job:worker/task:1"
-                    ], context.num_gpus())),
+                    cross_device_ops=cross_device_ops_lib.ReductionToOneDevice(
+                    ),
+                ),
                 required_gpus=1)
         ],
         mode=["graph"]))
@@ -1288,9 +1288,7 @@
     cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
 
   def _make_cross_device_ops(self):
-    return cross_device_ops_lib.MultiWorkerAllReduce(
-        ["/job:chief/task:0", "/job:worker/task:0", "/job:worker/task:1"],
-        context.num_gpus())
+    return cross_device_ops_lib.ReductionToOneDevice()
 
   def testMinimizeLossGraph(self):
     with context.graph_mode():
diff --git a/tensorflow/python/distribute/multi_worker_test_base.py b/tensorflow/python/distribute/multi_worker_test_base.py
index 408cad2..b0c51f4 100644
--- a/tensorflow/python/distribute/multi_worker_test_base.py
+++ b/tensorflow/python/distribute/multi_worker_test_base.py
@@ -41,6 +41,9 @@
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import multi_process_runner
+from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
+from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import remote
 from tensorflow.python.framework import errors
@@ -200,6 +203,156 @@
   return cluster
 
 
+class MultiProcessCluster(object):
+  """A cluster of TensorFlow servers in separate processes.
+
+  This class is not thread-safe.
+  """
+
+  def __init__(self, cluster_resolver):
+    self._cluster_resolver = cluster_resolver
+    self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
+    self._rpc_layer = cluster_resolver.rpc_layer
+    self._start_events = {}
+    self._finish_events = {}
+    self._mpr_manager = multi_process_runner.manager()
+
+    def task_function(start_events, finish_events):
+      cluster_resolver = TFConfigClusterResolver()
+      cluster_spec = cluster_resolver.cluster_spec()
+      task_type = cluster_resolver.task_type
+      task_id = cluster_resolver.task_id
+      rpc_layer = cluster_resolver.rpc_layer
+
+      logging.info(
+          'Starting server with cluster_spec = %r, task_type = %r, '
+          'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
+          rpc_layer)
+
+      # TODO(yuefengz): support GPU clusters.
+      server_config = config_pb2.ConfigProto()
+      server_config.device_count['GPU'] = 0
+
+      server_lib.Server(
+          cluster_spec,
+          job_name=task_type,
+          protocol=rpc_layer,
+          task_index=task_id,
+          config=server_config,
+          start=True)
+
+      start_event = start_events[task_type][task_id]
+      start_event.set()
+
+      finish_event = finish_events[task_type][task_id]
+      finish_event.wait()
+
+      os._exit(0)  # pylint: disable=protected-access
+
+    self._task_function = task_function
+    self._mpr = None
+
+  def start(self):
+    """Starts one TensorFlow server for each task in the cluster_resolver.
+
+    It waits until all the servers are up before returning.
+    """
+    if self._mpr:
+      raise ValueError('The cluster has already been started.')
+    for task_type, task_addresses in self._cluster_spec.items():
+      self._start_events[task_type] = []
+      self._finish_events[task_type] = []
+      for _ in task_addresses:
+        self._start_events[task_type].append(self._mpr_manager.Event())
+        self._finish_events[task_type].append(self._mpr_manager.Event())
+
+    self._mpr = multi_process_runner.MultiProcessRunner(
+        self._task_function,
+        self._cluster_spec,
+        args=(self._start_events, self._finish_events),
+        rpc_layer=self._rpc_layer,
+        stream_stdout=False,
+        list_stdout=False,
+        use_dill_for_args=False)
+    self._mpr.start()
+    for task_type, task_addresses in self._cluster_spec.items():
+      for i in range(len(task_addresses)):
+        self._start_events[task_type][i].wait()
+
+  def stop(self):
+    """Stops all the servers."""
+    for task_type, task_addresses in self._cluster_spec.items():
+      for i in range(len(task_addresses)):
+        self._finish_events[task_type][i].set()
+    try:
+      self._mpr.join()
+    except multi_process_runner.UnexpectedSubprocessExitError:
+      # TODO(yuefengz): investigate why processes exit with 255.
+      pass
+    self._mpr = None
+    self._start_events = {}
+    self._finish_events = {}
+
+  def kill_task(self, task_type, task_id):
+    """Kill a server given task_type and task_id.
+
+    Args:
+      task_type: the type of the task such as "worker".
+      task_id: the id the task such as 1.
+    """
+    assert self._mpr
+    if (not self._start_events[task_type][task_id].is_set() or
+        self._finish_events[task_type][task_id].is_set()):
+      raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))
+
+    self._finish_events[task_type][task_id].set()
+    self._mpr._processes[(task_type, task_id)].join()
+
+  def start_task(self, task_type, task_id):
+    """Starts a server given task_type and task_id.
+
+    Args:
+      task_type: the type of the task, such as "worker".
+      task_id: the id of the task, such as 1.
+
+    Raises:
+      ValueError: if the server is still alive.
+    """
+    assert self._mpr
+
+    if (not self._start_events[task_type][task_id].is_set() or
+        not self._finish_events[task_type][task_id].is_set()):
+      raise ValueError(
+          'The task %s:%d is still alive. You cannot start another one.' %
+          (task_type, task_id))
+    self._start_events[task_type][task_id] = self._mpr_manager.Event()
+    self._finish_events[task_type][task_id] = self._mpr_manager.Event()
+    self._mpr.start_single_process(task_type=task_type, task_id=task_id)
+    self._start_events[task_type][task_id].wait()
+
+  @property
+  def cluster_resolver(self):
+    return copy.deepcopy(self._cluster_resolver)
+
+
+def create_multi_process_cluster(num_workers,
+                                 num_ps,
+                                 has_chief=False,
+                                 has_eval=False,
+                                 rpc_layer='grpc'):
+  cluster_spec = create_cluster_spec(
+      has_chief=has_chief,
+      num_workers=num_workers,
+      num_ps=num_ps,
+      has_eval=has_eval)
+
+  cluster = MultiProcessCluster(
+      SimpleClusterResolver(
+          server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer))
+  cluster.start()
+  return cluster
+
+
 # TODO(rchao): Remove `test_obj` once estimator repo picks up the updated
 # nightly TF.
 def create_cluster_spec(has_chief=False,
diff --git a/tensorflow/python/distribute/multi_worker_test_base_test.py b/tensorflow/python/distribute/multi_worker_test_base_test.py
new file mode 100644
index 0000000..e660d28
--- /dev/null
+++ b/tensorflow/python/distribute/multi_worker_test_base_test.py
@@ -0,0 +1,82 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for multi-process clusters."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.distribute import multi_process_runner
+from tensorflow.python.distribute import multi_worker_test_base
+from tensorflow.python.eager import context
+from tensorflow.python.eager import remote
+from tensorflow.python.eager import test
+
+
+class MultiProcessClusterTest(test.TestCase):
+
+  def setUp(self):
+    super(MultiProcessClusterTest, self).setUp()
+    self._cluster = multi_worker_test_base.create_multi_process_cluster(
+        num_workers=2, num_ps=1, has_chief=True, rpc_layer="grpc")
+    remote.connect_to_cluster(
+        self._cluster.cluster_resolver.cluster_spec(), protocol="grpc")
+    context.ensure_initialized()
+
+  def testClusterIsAlive(self):
+    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
+    self.assertTrue(context.check_alive("/job:worker/replica:0/task:1"))
+    self.assertTrue(context.check_alive("/job:ps/replica:0/task:0"))
+    self.assertTrue(context.check_alive("/job:chief/replica:0/task:0"))
+
+  def testKillAndStartTask(self):
+    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
+
+    # It is not allowed to start a task before killing it.
+    with self.assertRaises(ValueError):
+      self._cluster.start_task("worker", 0)
+
+    self._cluster.kill_task("worker", 0)
+    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
+
+    # The task is already killed.
+    with self.assertRaises(ValueError):
+      self._cluster.kill_task("worker", 0)
+
+    self._cluster.start_task("worker", 0)
+
+    # Without a call to update_server_def, the next check_alive will return
+    # False. Alternatively, sleeping for 2 seconds here also works.
+    context.context().update_server_def(context.get_server_def())
+
+    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
+
+  def testStop(self):
+    self._cluster.stop()
+    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
+    self.assertFalse(context.check_alive("/job:worker/replica:0/task:1"))
+    self.assertFalse(context.check_alive("/job:ps/replica:0/task:0"))
+    self.assertFalse(context.check_alive("/job:chief/replica:0/task:0"))
+
+  def testClusterResolverProperty(self):
+    cluster_spec = self._cluster.cluster_resolver.cluster_spec().as_dict()
+
+    self.assertEqual(len(cluster_spec["worker"]), 2)
+    self.assertEqual(len(cluster_spec["ps"]), 1)
+    self.assertEqual(len(cluster_spec["chief"]), 1)
+
+
+if __name__ == "__main__":
+  multi_process_runner.test_main()
diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py
index 718fa80..199c676 100644
--- a/tensorflow/python/distribute/parameter_server_strategy_v2.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py
@@ -30,8 +30,10 @@
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_inspect
 
 
 # pylint: disable=protected-access
@@ -48,48 +50,47 @@
   is subject to changes.
   """
 
-  def __init__(self, cluster_resolver):
+  def __init__(self, cluster_resolver, variable_partitioner=None):
     """Initializes the V2 parameter server strategy.
 
     Args:
       cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
         object.
+      variable_partitioner: a callable with the signature `num_partitions =
+        fn(shape, dtype)`, where `num_partitions` is a list/tuple representing
+        the number of partitions on each axis, and `shape` and `dtype` are of
+        types `tf.TensorShape` and `tf.dtypes.DType`. If None, variables will
+        not be partitioned.
+        * `variable_partitioner` will be called for all variables created under
+        the strategy's `scope` to instruct how the variables should be
+        partitioned. A variable will be partitioned if there is more than one
+        partition along the partitioning axis; otherwise it falls back to a
+        normal `tf.Variable`.
+        * Only partitioning along the first / outermost axis is supported,
+        namely, all elements in `num_partitions` other than the first must
+        be 1.
+        * Partitioners like `min_max_variable_partitioner`,
+        `variable_axis_size_partitioner` and `fixed_size_partitioner` are also
+        supported since they conform to the required signature.
+        * The "div" partition strategy is used to partition variables.
+        Assuming we assign consecutive integer ids along the first axis of a
+        variable, then ids are assigned to shards in a contiguous manner, while
+        attempting to keep each shard size identical. If the ids do not evenly
+        divide the number of shards, each of the first several shards will be
+        assigned one more id. For instance, a variable whose first dimension is
+        13 has 13 ids, and they are split across 5 shards as:
+        `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
+        * Variables created under `strategy.extended.colocate_vars_with` will
+        not be partitioned, e.g., optimizer's slot variables.
     """
-    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver)
     self._cluster_resolver = cluster_resolver
+    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
+                                                       variable_partitioner)
     self._verify_args_and_config(cluster_resolver)
     logging.info(
         "ParameterServerStrategyV2 is initialized with cluster_spec: "
         "%s", cluster_resolver.cluster_spec())
     super(ParameterServerStrategyV2, self).__init__(self._extended)
 
-  @tf_contextlib.contextmanager
-  def experimental_variable_partitioning_scope(self):
-    """A context manager for creating `ShardedVariable`.
-
-    Variables created inside a `with experimental_variable_partitioning_scope()`
-    code block will be of type `ShardedVariable` and their values are
-    partitioned among parameter servers along the first / outermost axis. The
-    number of shards are equal to the number of parameter servers.
-
-    Variables created within this scope must be initialized using a callable as
-    `initial_value` and a known shape.
-
-    Div partition strategy is used to partition the variable. Assuming we
-    assign consective integer ids along the first axis of the variable, then ids
-    are assigned to shards in a contiguous manner, while attempting to keep each
-    shard size identical. If the ids do not evenly divide the number of shards,
-    each of the first several shards will be assigned one more id. For instance,
-    a variable whose first dimension is 13 has 13 ids, and they are split across
-    5 shards as: `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
-
-    Yields:
-      A context manager for creating `ShardedVariable`.
-    """
-    with variable_scope.variable_creator_scope(
-        self._extended._make_sharded_variable_creator()):
-      yield
-
   def _verify_args_and_config(self, cluster_resolver):
     if not cluster_resolver.cluster_spec():
       raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
@@ -104,11 +105,19 @@
   Please see `tf.distribute.StrategyExtended` doc for more information.
   """
 
-  def __init__(self, container_strategy, cluster_resolver):
+  def __init__(self, container_strategy, cluster_resolver,
+               variable_partitioner):
     """Initialization of ParameterServerStrategyV2Extended."""
     super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
     self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
     self._variable_count = 0
+    self._variable_partitioner = variable_partitioner
+
+  @tf_contextlib.contextmanager
+  def _scope(self, strategy):
+    with super()._scope(strategy):
+      with variable_scope.variable_creator_scope(self._make_variable_creator()):
+        yield
 
   def _create_variable(self, next_creator, **kwargs):
 
@@ -136,69 +145,114 @@
         self._variable_count += 1
         return var
 
-  def _make_sharded_variable_creator(self):
+  def _make_variable_creator(self):
     """Returns a function conforming to the `variable_creator` signature.
 
-    The returned function creates `ShardedVariable` when called.
+    The returned function creates `ShardedVariable` or `Variable` when called.
+    A `ShardedVariable` will be created if all the following criteria are met:
+      1. `self._variable_partitioner` results in more than one partition on the
+         first axis.
+      2. The variable's rank is greater than 0.
+      3. The variable is not colocated with other variables.
+    Otherwise a `Variable` will be created.
     """
 
-    def sharded_variable_creator(next_creator, **kwargs):
-      if "shape" not in kwargs or kwargs["shape"] is None:
-        raise ValueError("shape must be explicitly specified when creating "
-                         "sharded variables")
-      init_fn = kwargs.get("initial_value", None)
-      # We intentionally don't allow non-callable initial_value to ensure the
-      # value is created on PS but not client. If the value is created on
-      # client, it will needed to be sent to PS for variable initialization,
-      # which is inefficient and can potentially hit the 2GB limit on protobuf
-      # serialization.
-      if init_fn is None or not callable(init_fn):
-        raise ValueError("initial_value must be specified as a callable when "
-                         "creating sharded variables")
+    def variable_creator(next_creator, **kwargs):
+      if self._variable_partitioner is None:
+        return next_creator(**kwargs)
+
+      if "colocate_with" in kwargs:  # Never partition colocated_with variables.
+        return next_creator(**kwargs)
+
+      name = kwargs.get("name", None)
+      initial_value = kwargs.get("initial_value", None)
+      if initial_value is None:
+        raise ValueError("initial_value must be specified.")
+      init_from_fn = callable(initial_value)
+
+      dtype = kwargs.get("dtype", None)
+      shape = kwargs.get("shape", None)
+      if init_from_fn and (shape is None or dtype is None):
+        init_from_fn = False
+        initial_value = initial_value()
+      if not init_from_fn:
+        # The initial_value is created on the client; it will need to be sent
+        # to the PS for variable initialization, which can be inefficient and
+        # can potentially hit the 2GB limit on protobuf serialization.
+        initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
+        dtype = initial_value.dtype
+        shape = initial_value.shape
+      else:
+        shape = tensor_shape.as_shape(shape)
+
+      if shape.rank == 0:  # Skip partitioning rank-0 variable.
+        return next_creator(**kwargs)
+
+      num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
+      if not num_partitions or num_partitions[0] == 0 or any(
+          v != 1 for v in num_partitions[1:]):
+        raise ValueError(
+            "variable_partitioner must return a list/tuple whose elements are "
+            "all 1 except the first element (which must be non-zero), got: %r"
+            % num_partitions)
+
+      if num_partitions[0] == 1:  # no partition
+        return next_creator(**kwargs)
 
       # Use "div" partition strategy to partition the variable.
-      full_shape = kwargs["shape"]
-      if self._num_ps < full_shape[0]:
-        num_shards = self._num_ps
-      else:
-        num_shards = full_shape[0]
+      num_partitions = min(num_partitions[0], shape[0])
+      base = shape[0] // num_partitions
+      extra = shape[0] % num_partitions
+      # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
+      # offsets: [0, 3, 6, 8, 10]
       offsets = []
-      base = full_shape[0] // num_shards
-      extra = full_shape[0] % num_shards
-      for i in range(num_shards):
+      for i in range(num_partitions):
         if i == 0:
           offsets.append(0)
         else:
           prev_shard_size = base + (1 if i - 1 < extra else 0)
           offsets.append(offsets[i - 1] + prev_shard_size)
+      offsets.append(shape[0])
 
-      # Note: The way we initialize sharded variables is suboptimal, as it
-      # needs to create the full value tensor separately on each PS which the
-      # variable is going to be placed on. The full value could be very large
-      # and consume a lot of memory. The ideal way is to only create what's
-      # needed on the shard, however that's not practical because:
-      #  1. Initializers don't have sharded behavior support, even though some
-      #     initializers (e.g, uniform) can be used directly.
-      #  2. tf.Variable signature requires "initial_value" to be either a value
-      #     or a callable without arguments, meaning it is not straightforward
-      #     to make the sharded component from it.
       def init_shard_fn(shard_index):
-        full_value = init_fn()
-        if shard_index < num_shards - 1:
+        if not init_from_fn:
+          logging.log_if(
+              logging.WARNING, _INEFFICIENT_INIT_WARNING % name,
+              shard_index == 0 and
+              shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
+          return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
+        arg_spec = tf_inspect.getfullargspec(initial_value)
+        if ("partition" not in arg_spec.args and
+            "partition" not in arg_spec.kwonlyargs):
+          # `initial_value` is a callable that doesn't accept a `partition`
+          # argument.
+          logging.log_if(
+              logging.WARNING, _INEFFICIENT_INIT_WARNING % name,
+              shard_index == 0 and
+              shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
+          full_value = initial_value()
           return full_value[offsets[shard_index]:offsets[shard_index + 1]]
         else:
-          return full_value[offsets[shard_index]:]
+          # Memory-efficient way of initializing a sharded variable: it
+          # requires `initial_value` to accept a `partition` keyword argument,
+          # a `Partition` namedtuple.
+          component_shape = (offsets[shard_index + 1] -
+                             offsets[shard_index],) + shape[1:]
+          offsets_all_axes = (offsets[shard_index],) + (0,) * len(shape[1:])
+          return initial_value(
+              partition=sharded_variable.Partition(
+                  shape=tensor_shape.as_shape(component_shape),
+                  offsets=offsets_all_axes))
 
       var_list = []
-      for i in range(num_shards):
-        kwargs["shape"] = None
+      for i in range(num_partitions):
+        kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
         kwargs["initial_value"] = lambda: init_shard_fn(i)
+        if name is not None:
+          kwargs["name"] = "{}/part_{}".format(name, i)
         var_list.append(next_creator(**kwargs))
 
       result = sharded_variable.ShardedVariable(var_list)
       return result
 
-    return sharded_variable_creator
+    return variable_creator
 
   def _call_for_each_replica(self, fn, args, kwargs):
     with distribute_lib.ReplicaContext(
@@ -206,3 +260,18 @@
         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
       # TODO(rchao): Support multi-replica per worker or sync-group.
       return distribute_utils.regroup((fn(*args, **kwargs),))
+
+
+# Warning logged when a sharded variable is initialized in a
+# memory-inefficient way.
+_INEFFICIENT_INIT_WARNING = (
+    "Large variable %s is partitioned but not initialized in a "
+    "memory-efficient way. The full value is first created and then sliced "
+    "into smaller values. To reduce the memory footprint, explicitly specify "
+    "`dtype` and `shape` when creating variables, and pass a callable to "
+    "Variable's `initial_value`. The callable should take a single "
+    "`partition` argument, a namedtuple (shape: `tf.TensorShape`, offsets: "
+    "list/tuple), where shape is the shape of the component variable and "
+    "offsets are the offsets of the shard along each axis.")
+
+_LARGE_VARIABLE_NUM_ELEMENTS = 1e9
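
To illustrate the memory-efficient path above, a minimal sketch of an
`initial_value` callable that accepts the `partition` keyword and builds only
the shard it is asked for (`range_init` is a hypothetical helper, not part of
this patch):

import functools
import numpy as np

def range_init(shape, dtype, partition=None):
  """Hypothetical initializer that materializes only the requested shard."""
  if partition is None:
    return np.arange(shape[0], dtype=dtype)
  start = int(partition.offsets[0])
  size = int(partition.shape[0])
  return np.arange(start, start + size, dtype=dtype)

# With `shape` and `dtype` bound ahead of time, the variable creator can call
# initial_value(partition=sharded_variable.Partition(...)) once per shard.
initial_value = functools.partial(range_init, shape=(10,), dtype=np.int64)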
diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2_test.py b/tensorflow/python/distribute/parameter_server_strategy_v2_test.py
index 1b1e7d8..64e2f24 100644
--- a/tensorflow/python/distribute/parameter_server_strategy_v2_test.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_v2_test.py
@@ -18,11 +18,24 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+
+import functools
+import platform
+import sys
+
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
+from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import remote
 from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops_v2
+from tensorflow.python.ops import linalg_ops_impl
+from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import variables
 from tensorflow.python.training.server_lib import ClusterSpec
 
@@ -33,15 +46,23 @@
   def setUpClass(cls):
     super(ParameterServerStrategyV2Test, cls).setUpClass()
     cluster_def = multi_worker_test_base.create_in_process_cluster(
-        num_workers=2, num_ps=3, rpc_layer="grpc")
-    cls.cluster_resolver = SimpleClusterResolver(
-        ClusterSpec(cluster_def), rpc_layer="grpc")
+        num_workers=2, num_ps=3)
+    cls.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))
     remote.connect_to_cluster(
-        cls.cluster_resolver.cluster_spec(),
-        job_name="chief",
-        protocol=cls.cluster_resolver.rpc_layer)
+        cls.cluster_resolver.cluster_spec(), job_name="chief")
+
+  @classmethod
+  def tearDownClass(cls):
+    super(ParameterServerStrategyV2Test, cls).tearDownClass()
+    # Reset the context to disconnect from the cluster.
+    context._reset_context()
 
   def testVariablePlacement(self):
+
+    if sys.version_info >= (3, 8) and platform.system() == "Windows":
+      # TODO(b/165013260): Fix this
+      self.skipTest("Test is currently broken on Windows with Python 3.8")
+
     strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
         self.cluster_resolver)
     v1 = variables.Variable(initial_value=0.0)
@@ -59,5 +80,246 @@
     self.assertEqual(v5.device, "/job:ps/replica:0/task:0/device:CPU:0")
 
 
+class PartitionAwareIdentity(object):
+
+  def __call__(self, shape, dtype, partition):
+    value = linalg_ops_impl.eye(*shape, dtype=dtype)
+    if partition is not None:
+      value = array_ops.slice(value, partition.offsets, partition.shape)
+    return value
+
+
+class VariablePartitioningTest(test.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super(VariablePartitioningTest, cls).setUpClass()
+    cluster_def = multi_worker_test_base.create_in_process_cluster(
+        num_workers=2, num_ps=2)
+    cls.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))
+    remote.connect_to_cluster(cls.cluster_resolver.cluster_spec())
+
+  @classmethod
+  def tearDownClass(cls):
+    super(VariablePartitioningTest, cls).tearDownClass()
+    # Reset the context to disconnect from the cluster.
+    context._reset_context()
+
+  def setUp(self):
+    super().setUp()
+    if sys.version_info >= (3, 8) and platform.system() == "Windows":
+      # TODO(b/165013260): Fix this
+      self.skipTest("Test is currently broken on Windows with Python 3.8")
+
+  def testDefaultNoPartition(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver)
+    with strategy.scope():
+      v = variables.Variable([0, 1, 2, 3])
+
+    self.assertIsInstance(v, variables.Variable)
+
+  def testBasic(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2))
+    with strategy.scope():
+      init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+      v1 = variables.Variable(
+          initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
+          shape=(5, 2),
+          dtype=dtypes.int64)
+
+      init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
+      v2 = variables.Variable(
+          initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
+          shape=(6, 1),
+          dtype=dtypes.int64)
+
+    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
+    self.assertLen(v1.variables, 2)
+    self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
+    self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
+    self.assertAllEqual(v1.variables[0].read_value().numpy(),
+                        [[0, 1], [2, 3], [4, 5]])
+    self.assertAllEqual(v1.variables[1].read_value().numpy(), [[6, 7], [8, 9]])
+
+    self.assertIsInstance(v2, sharded_variable.ShardedVariable)
+    self.assertLen(v2.variables, 2)
+    self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
+    self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
+    self.assertAllEqual(v2.variables[0].read_value().numpy(), [[0], [1], [2]])
+    self.assertAllEqual(v2.variables[1].read_value().numpy(), [[3], [4], [5]])
+
+  def testNonCallableInitialValue(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, partitioned_variables.fixed_size_partitioner(4))
+    with strategy.scope():
+      v = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+
+    self.assertIsInstance(v, sharded_variable.ShardedVariable)
+    self.assertLen(v.variables, 4)
+    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
+    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
+    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
+    self.assertRegex(v.variables[3].device, "/job:ps/replica:0/task:1")
+    self.assertAllEqual(v.variables[0].read_value().numpy(), [0, 1, 2])
+    self.assertAllEqual(v.variables[1].read_value().numpy(), [3, 4, 5])
+    self.assertAllEqual(v.variables[2].read_value().numpy(), [6, 7])
+    self.assertAllEqual(v.variables[3].read_value().numpy(), [8, 9])
+
+  def testNumPartitionsLargerThanSize(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, partitioned_variables.fixed_size_partitioner(4))
+    with strategy.scope():
+      v = variables.Variable([0, 1, 2])
+
+    self.assertIsInstance(v, sharded_variable.ShardedVariable)
+    self.assertLen(v.variables, 3)
+    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
+    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
+    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
+    self.assertAllEqual(v.variables[0].read_value().numpy(), [0])
+    self.assertAllEqual(v.variables[1].read_value().numpy(), [1])
+    self.assertAllEqual(v.variables[2].read_value().numpy(), [2])
+
+  def testPartitionToOne(self):
+    # For small variables there is only one partition.
+    variable_partitioner = partitioned_variables.min_max_variable_partitioner(
+        max_partitions=2, min_slice_size=64 << 20)
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, variable_partitioner)
+    with strategy.scope():
+      initializer = init_ops_v2.Constant([0] * 10)
+      v1 = variables.Variable(
+          initial_value=lambda: initializer(shape=(10,), dtype=dtypes.int64),
+          shape=(10,),
+          dtype=dtypes.int64)
+
+      v2 = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+
+    self.assertIsInstance(v1, variables.Variable)
+    self.assertNotIsInstance(v1, sharded_variable.ShardedVariable)
+    self.assertRegex(v1.device, "/job:ps/replica:0/task:0")
+    self.assertAllEqual(v1.read_value().numpy(), [0] * 10)
+
+    self.assertIsInstance(v2, variables.Variable)
+    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
+    self.assertRegex(v2.device, "/job:ps/replica:0/task:1")
+    self.assertAllEqual(v2.read_value().numpy(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+
+  def testColocateWith(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2))
+    with strategy.scope():
+      v1 = variables.Variable([0, 1, 2, 3])
+
+      with strategy.extended.colocate_vars_with(v1.variables[0]):
+        v2 = variables.Variable([4, 5])
+
+    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
+
+    self.assertIsInstance(v2, variables.Variable)
+    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
+    self.assertEqual(v2.device, v1.variables[0].device)
+    self.assertAllEqual(v2.read_value().numpy(), [4, 5])
+
+  def testPartitionAwareInitializer(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2))
+    with strategy.scope():
+      initializer = PartitionAwareIdentity()
+      initial_value = functools.partial(
+          initializer, shape=(4, 4), dtype=dtypes.int64)
+      v = variables.Variable(
+          initial_value=initial_value, shape=(4, 4), dtype=dtypes.int64)
+
+    self.assertIsInstance(v, sharded_variable.ShardedVariable)
+    self.assertLen(v.variables, 2)
+    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
+    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
+    self.assertAllEqual(v.variables[0].read_value().numpy(),
+                        [[1, 0, 0, 0], [0, 1, 0, 0]])
+    self.assertAllEqual(v.variables[1].read_value().numpy(),
+                        [[0, 0, 1, 0], [0, 0, 0, 1]])
+
+  def testPartitionWhenLackOfInfo(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2))
+    with strategy.scope():
+      initializer = init_ops_v2.Constant([0, 1, 2, 3])
+      # Shape is not explicitly specified.
+      v1 = variables.Variable(
+          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
+          dtype=dtypes.int64)
+      # Dtype is not explicitly specified.
+      v2 = variables.Variable(
+          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
+          shape=(4,))
+      # Neither shape nor dtype is explicitly specified.
+      v3 = variables.Variable(
+          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64))
+
+    for v in [v1, v2, v3]:
+      self.assertIsInstance(v, sharded_variable.ShardedVariable)
+      self.assertLen(v.variables, 2)
+      self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
+      self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
+      self.assertAllEqual(v.variables[0].read_value().numpy(), [0, 1])
+      self.assertAllEqual(v.variables[1].read_value().numpy(), [2, 3])
+
+  def testInvalidPartitioner(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, lambda shape, dtype: None)
+    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
+      with strategy.scope():
+        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])
+
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, lambda shape, dtype: [])
+    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
+      with strategy.scope():
+        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])
+
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, lambda shape, dtype: [0, 1, 1])
+    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
+      with strategy.scope():
+        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])
+
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, lambda shape, dtype: [2, 2, 1])
+    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
+      with strategy.scope():
+        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])
+
+  def testCreateInsideTFFunction(self):
+    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
+        self.cluster_resolver, partitioned_variables.fixed_size_partitioner(2))
+
+    collection = []
+
+    @def_function.function
+    def create_vars():
+      if not collection:
+        identity = init_ops_v2.Identity()
+        v1 = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
+        v2 = variables.Variable(lambda: identity((2, 2), dtypes.float32))
+        v3 = variables.Variable(
+            lambda: identity((2, 2), dtypes.float32),
+            dtype=dtypes.float32,
+            shape=(2, 2))
+        collection.extend([v1, v2, v3])
+
+    with strategy.scope():
+      create_vars()
+      for v in collection:
+        self.assertIsInstance(v, sharded_variable.ShardedVariable)
+        self.assertLen(v.variables, 2)
+        self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
+        self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
+        self.assertAllEqual(v.variables[0].read_value().numpy(), [[1., 0.]])
+        self.assertAllEqual(v.variables[1].read_value().numpy(), [[0., 1.]])
+
+
 if __name__ == "__main__":
   test.main()
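
For reference, a partitioner satisfying the contract checked in
`testInvalidPartitioner` can be as small as the sketch below (a hypothetical
helper in the spirit of `partitioned_variables.fixed_size_partitioner`): the
first element is the shard count and every remaining element is 1.

def fixed_first_axis_partitioner(num_shards):
  """Hypothetical partitioner that shards only the first axis."""
  def partitioner(shape, dtype):
    del dtype  # Partitioning here depends only on the shape.
    return [num_shards] + [1] * (shape.rank - 1)
  return partitioner

# For a variable of shape (10, 2) and num_shards=2 this returns [2, 1].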
diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py
index ea0fbf1..c68d3f5 100644
--- a/tensorflow/python/distribute/sharded_variable.py
+++ b/tensorflow/python/distribute/sharded_variable.py
@@ -17,9 +17,12 @@
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import copy
 
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.saved_model import save_context
@@ -122,6 +125,13 @@
           'axis, found {}'.format([v.shape for v in variables]))
     first_dim = sum(int(v.shape[0]) for v in variables)
     self._shape = tensor_shape.TensorShape([first_dim] + first_var.shape[1:])
+    self._var_offsets = [
+        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
+    ]
+    for i in range(1, len(variables)):
+      # Always partition on the first axis. Offsets on other axes are 0.
+      self._var_offsets[i][0] += (
+          self._var_offsets[i - 1][0] + variables[i - 1].shape[0])
 
     save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
     if any(slice_info is not None for slice_info in save_slice_info):
@@ -162,6 +172,20 @@
     """The overall shape, combining all shards along axis `0`."""
     return self._shape
 
+  def assign(self, value, use_locking=None, name=None, read_value=True):
+    for i, v in enumerate(self._variables):
+      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
+    return self
+
+  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
+    for i, v in enumerate(self._variables):
+      v.assign_add(
+          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
+    return self
+
+  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
+    for i, v in enumerate(self._variables):
+      v.assign_sub(
+          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
+    return self
+
   def _gather_saveables_for_checkpoint(self):
     """Return a `Saveable` for each shard. See `Trackable`."""
 
@@ -195,3 +219,22 @@
                                     name=self.name)
 
     return obj_map, resource_map
+
+
+def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
+  del name
+  if dtype is not None and not dtype.is_compatible_with(var.dtype):
+    raise ValueError(
+        'Incompatible type conversion requested to type {!r} for variable '
+        'of type {!r}'.format(dtype.name, var.dtype.name))
+  if as_ref:
+    raise NotImplementedError(
+        "ShardedVariable doesn't support being used as a reference.")
+  return array_ops.concat(var.variables, axis=0)
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)
+
+Partition = collections.namedtuple('Partition', ['shape', 'offsets'])
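
To make the new offset bookkeeping and `assign` slicing concrete, a small
sketch (not part of the patch) using the same three-shard layout as the tests
below:

import tensorflow as tf

# Shards with first-axis sizes 1, 2 and 1: offsets are a running sum on
# axis 0 and zero on every other axis.
shard_shapes = [[1, 2], [2, 2], [1, 2]]
var_offsets = [[0, 0]]
for prev in shard_shapes[:-1]:
  var_offsets.append([var_offsets[-1][0] + prev[0], 0])
assert var_offsets == [[0, 0], [1, 0], [3, 0]]

# assign() hands each shard its sub-block via tf.slice(value, begin, size).
value = tf.constant([[4, 4], [5, 5], [6, 6], [7, 7]])
shard_1 = tf.slice(value, var_offsets[1], shard_shapes[1])
# shard_1 == [[5, 5], [6, 6]]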
diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py
index 64ed3d0..9bea7b8 100644
--- a/tensorflow/python/distribute/sharded_variable_test.py
+++ b/tensorflow/python/distribute/sharded_variable_test.py
@@ -72,6 +72,44 @@
     self.assertEqual(s.dtype, v0.dtype)
     self.assertEqual(s.name, 's')
 
+  def test_assign(self):
+    v0 = variables_lib.Variable([[0, 0]])
+    v1 = variables_lib.Variable([[1, 1], [2, 2]])
+    v2 = variables_lib.Variable([[3, 3]])
+    s = sharded_variable.ShardedVariable([v0, v1, v2])
+    s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
+    self.assertAllEqual(self.evaluate(s.variables[0]), [[4, 4]])
+    self.assertAllEqual(self.evaluate(s.variables[1]), [[5, 5], [6, 6]])
+    self.assertAllEqual(self.evaluate(s.variables[2]), [[7, 7]])
+
+  def test_assign_add(self):
+    v0 = variables_lib.Variable([[0, 0]])
+    v1 = variables_lib.Variable([[1, 1], [2, 2]])
+    v2 = variables_lib.Variable([[3, 3]])
+    s = sharded_variable.ShardedVariable([v0, v1, v2])
+    s.assign_add([[1, 1], [1, 1], [2, 2], [2, 2]])
+    self.assertAllEqual(self.evaluate(s.variables[0]), [[1, 1]])
+    self.assertAllEqual(self.evaluate(s.variables[1]), [[2, 2], [4, 4]])
+    self.assertAllEqual(self.evaluate(s.variables[2]), [[5, 5]])
+
+  def test_assign_sub(self):
+    v0 = variables_lib.Variable([[0, 0]])
+    v1 = variables_lib.Variable([[1, 1], [2, 2]])
+    v2 = variables_lib.Variable([[3, 3]])
+    s = sharded_variable.ShardedVariable([v0, v1, v2])
+    s.assign_sub([[0, 0], [1, 1], [1, 1], [3, 3]])
+    self.assertAllEqual(self.evaluate(s.variables[0]), [[0, 0]])
+    self.assertAllEqual(self.evaluate(s.variables[1]), [[0, 0], [1, 1]])
+    self.assertAllEqual(self.evaluate(s.variables[2]), [[0, 0]])
+
+  def test_convert_to_tensor(self):
+    v0 = variables_lib.Variable([[0, 0]])
+    v1 = variables_lib.Variable([[1, 1], [2, 2]])
+    v2 = variables_lib.Variable([[3, 3]])
+    s = sharded_variable.ShardedVariable([v0, v1, v2])
+    t = ops.convert_to_tensor(s)
+    self.assertAllEqual(t, [[0, 0], [1, 1], [2, 2], [3, 3]])
+
   def test_save_restore(self):
     fname = os.path.join(self.get_temp_dir(), 'checkpoint')
     variables = [
diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index c131892..c2aa68a 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -454,8 +454,7 @@
     self.assertAllEqual(expected_result, run(input_iterator))
     self.assertAllEqual((0.,), w.read_value())
 
-  # TODO(b/140633529): Re-enable the test.
-  def disable_test_experimental_run_output_on_device(self, enable_packed_var):
+  def test_experimental_run_output_on_device(self, enable_packed_var):
     strategy = get_tpu_strategy(enable_packed_var)
 
     def computation(x):
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index bcbada7..68f902a 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -37,7 +37,6 @@
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.saved_model import save_context
 from tensorflow.python.training.saving import saveable_object
-from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.types import core
 from tensorflow.python.util.tf_export import tf_export
@@ -435,7 +434,7 @@
             self.traceback == o.traceback and self.type == o.type)
 
   def __hash__(self):
-    return hash((self.name, self.graph, self.traceback, self.type))
+    return hash((self.name, self.graph, tuple(self.traceback), self.type))
 
 
 class DistributedVariable(DistributedDelegate, variables_lib.Variable,
@@ -952,6 +951,13 @@
     return obj_map, resource_map
 
 
+# We extend from `saveable_object.SaveableObject` instead of
+# `saveable_object_util.ResourceVariableSaveable` since we need to read the
+# value of ON_READ variables when saving. `SaveableObject` provides a way to
+# specify the function to run to get the value of the variable or tensor at
+# saving time. We can use this for both ON_READ and ON_WRITE variables.
+# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic
+# if possible.
 class _DistributedVariableSaveable(saveable_object.SaveableObject):
   """Class for defining how to restore a DistributedVariable."""
 
@@ -971,26 +977,21 @@
         self._distributed_variable, tensor)
 
 
-class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable):
+class _MirroredSaveable(saveable_object.SaveableObject):
   """Class for defining how to restore a MirroredVariable."""
 
   def __init__(self, mirrored_variable, primary_variable, name):
     self._mirrored_variable = mirrored_variable
-    super(_MirroredSaveable, self).__init__(primary_variable, "", name)
+    tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
+                                                     primary_variable,
+                                                     name)
+    super(_MirroredSaveable, self).__init__(tensor, spec, name)
 
   def restore(self, restored_tensors, restored_shapes):
     """Restore the same value into all variables."""
     tensor, = restored_tensors
-    packed_var = self._mirrored_variable._packed_variable  # pylint: disable=protected-access
-    if packed_var is not None:
-      return control_flow_ops.group(
-          tuple(
-              values_util.assign_on_device(d, packed_var, tensor)
-              for d in packed_var.devices))
-    return control_flow_ops.group(
-        tuple(
-            values_util.assign_on_device(v.device, v, tensor)
-            for v in self._mirrored_variable.values))
+    return values_util.get_on_write_restore_ops(self._mirrored_variable,
+                                                tensor)
 
 
 class MirroredVariable(DistributedVariable, Mirrored):
@@ -1027,8 +1028,6 @@
     return super(MirroredVariable, self).scatter_update(*args, **kwargs)
 
   def _get_cross_replica(self):
-    if values_util.is_saving_non_distributed():
-      return self._primary.read_value()
     # Return identity, to avoid directly exposing the variable to the user and
     # allowing it to be modified by mistake.
     return array_ops.identity(Mirrored._get_cross_replica(self))
@@ -1074,38 +1073,17 @@
 
   def __init__(self, sync_on_read_variable, name):
     self._sync_on_read_variable = sync_on_read_variable
+    tensor, spec = values_util.get_on_read_saveable(
+        sync_on_read_variable, sync_on_read_variable._primary, name)
 
-    # We use a callable so that we don't have to evaluate this expression
-    # in the case where we are trying to restore instead of save.
-    def tensor():
-      strategy = sync_on_read_variable._distribute_strategy  # pylint: disable=protected-access
-      return strategy.extended.read_var(sync_on_read_variable)
-
-    spec = saveable_object.SaveSpec(
-        tensor=tensor,
-        slice_spec="",
-        name=name,
-        dtype=sync_on_read_variable.dtype,
-        device=sync_on_read_variable._primary.device)  # pylint: disable=protected-access
-
-    super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name)
+    super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)
 
   def restore(self, restored_tensors, restored_shapes):
     """Restore the same value into all variables."""
-    # To preserve the sum across save and restore, we have to divide the
-    # total across all devices when restoring a variable that was summed
-    # when saving.
     tensor, = restored_tensors
-    if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM:
-      # pylint: disable=protected-access
-      strategy = self._sync_on_read_variable._distribute_strategy
-      tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
-                             self._sync_on_read_variable.dtype)
-      # pylint: enable=protected-access
-    return control_flow_ops.group(
-        tuple(
-            values_util.assign_on_device(v.device, v, tensor)
-            for v in self._sync_on_read_variable.values))
+    return values_util.get_on_read_restore_ops(
+        self._sync_on_read_variable, tensor,
+        self._sync_on_read_variable.aggregation)
 
 
 class SyncOnReadVariable(DistributedVariable):
@@ -1206,8 +1184,6 @@
         return self._get_on_device_or_primary().value()
 
   def _get_cross_replica(self):
-    if values_util.is_saving_non_distributed():
-      return self._primary.read_value()
     if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
       # Consider returning a tensor value here to make the return value of
       # _get_cross_replica consistent.
@@ -1432,35 +1408,11 @@
 
   def get_saveable(self, var, primary_var, name):
     """Create a saveable object for the given variable."""
-
-    # We use a callable so that we don't have to evaluate this expression
-    # in the case where we are trying to restore instead of save.
-    def tensor():
-      strategy = var.distribute_strategy
-      return strategy.extended.read_var(var)
-
-    spec = saveable_object.SaveSpec(
-        tensor=tensor,
-        slice_spec="",
-        name=name,
-        dtype=var.dtype,
-        device=primary_var.device)
-
-    return tensor, [spec]
+    return values_util.get_on_read_saveable(var, primary_var, name)
 
   def get_restore_ops(self, var, tensor):
     """Restore the same value into all variables."""
-    # To preserve the sum across save and restore, we have to divide the
-    # total across all devices when restoring a variable that was summed
-    # when saving.
-    if self._aggregation == vs.VariableAggregation.SUM:
-      strategy = var._distribute_strategy  # pylint: disable=protected-access
-      num_replicas_in_sync = strategy.num_replicas_in_sync
-      tensor = math_ops.cast(tensor / num_replicas_in_sync, var.dtype)
-    return control_flow_ops.group(
-        tuple(
-            values_util.assign_on_device(v.device, v, tensor)
-            for v in var.values))
+    return values_util.get_on_read_restore_ops(var, tensor, self._aggregation)
 
 
 class AutoPolicy(VariablePolicy):
@@ -1484,7 +1436,7 @@
   def _get_cross_replica(self, var):
     # Return identity, to avoid directly exposing the variable to the user and
     # allowing it to be modified by mistake.
-    return array_ops.identity(Mirrored._get_cross_replica(var))  # pylint: disable=protected-access
+    return array_ops.identity(var._get_on_device_or_primary())  # pylint: disable=protected-access
 
   def _update_replica(self, var, update_fn, value, **kwargs):
     return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
@@ -1545,14 +1497,11 @@
                                       name=name)
 
   def get_saveable(self, var, primary_var, name):
-    del var, name
-    return primary_var, ""
+    """Saveable ops for AUTO variables."""
+    return values_util.get_on_write_saveable(var, primary_var, name)
 
   def get_restore_ops(self, var, tensor):
-    return control_flow_ops.group(
-        tuple(
-            values_util.assign_on_device(v.device, v, tensor)
-            for v in var.values))
+    return values_util.get_on_write_restore_ops(var, tensor)
 
 
 class OnWritePolicy(AutoPolicy):
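
Why `__hash__` now wraps `self.traceback` in `tuple(...)`: the captured stack
is a list-like object, and a tuple containing a list is unhashable. A quick
plain-Python illustration ("frame_a" stands in for a real stack frame):

traceback = ["frame_a", "frame_b"]  # Stand-in for an extracted stack.
try:
  hash(("name", traceback))         # The old key shape.
except TypeError as e:
  print(e)                          # unhashable type: 'list'
print(hash(("name", tuple(traceback))))  # The new key shape hashes fine.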
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index e4926f2..02a9926 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -81,15 +81,23 @@
   return values_lib.Mirrored(v)
 
 
-def _make_mirrored():
+def _make_mirrored(distribution=None):
   v = []
-  devices = ["/device:GPU:0", "/device:CPU:0"]
+  if distribution:
+    devices = distribution.extended.worker_devices
+  else:
+    devices = ["/device:GPU:0", "/device:CPU:0"]
   for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
     with ops.device(d):
-      v.append(variable_scope.get_variable(
-          name=n, initializer=init, use_resource=True))
-  mirrored = values_lib.MirroredVariable(
-      None, v, variable_scope.VariableAggregation.SUM)
+      v.append(
+          variable_scope.get_variable(
+              name=n, initializer=init, use_resource=True))
+
+  if (distribution is not None) and isinstance(distribution, _TPU_STRATEGIES):
+    var_cls = tpu_values.TPUMirroredVariable
+  else:
+    var_cls = values_lib.MirroredVariable
+  mirrored = var_cls(distribution, v, variable_scope.VariableAggregation.SUM)
   return mirrored
 
 
@@ -409,7 +417,7 @@
             strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
             strategy_combinations.tpu_strategy,
             strategy_combinations.tpu_strategy_packed_var,
-            strategy_combinations.central_storage_strategy_with_two_gpus,
+            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
             strategy_combinations.multi_worker_mirrored_2x1_cpu,
             strategy_combinations.multi_worker_mirrored_2x1_gpu,
             strategy_combinations.multi_worker_mirrored_2x2_gpu
@@ -423,7 +431,8 @@
             variables_lib.VariableAggregation.SUM,
             variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
         ],
-        mode=["graph", "eager"]))
+        mode=["graph", "eager"],
+        use_var_policy=[True, False]))
 class DistributedVariableTest(test.TestCase, parameterized.TestCase):
 
   def testExtendsVariable(self, distribution, synchronization, aggregation):
@@ -554,7 +563,10 @@
         self.assertIsInstance(v2.get(), type(v1.get()))
         self.assertNotEqual(id(v1.get()), id(v2.get()))
       else:
-        self.assertEqual(v1._policy, v2._policy)  # pylint: disable=protected-access
+        if v1._policy:
+          self.assertNotEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
+        else:
+          self.assertEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
         self.assertEqual(len(v1.values), len(v2.values))
         for (v1v, v2v) in zip(v1.values, v2.values):
           self.assertEqual(v1v.device, v2v.device)
@@ -900,6 +912,9 @@
     self.assertEqual(v.dtype, mirrored.dtype)
     self.assertEqual(v.shape, mirrored.shape)
 
+
+class MirroredVariableSaveRestoreTest(test.TestCase, parameterized.TestCase):
+
   def _assign_mirrored(self, v, new):
     for var, n in zip(v.values, new):
       self.evaluate(var.assign(n))
@@ -914,37 +929,10 @@
     save_path, _ = self._save_return_saver(sess, var)
     return save_path
 
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testSaveAndRestoreMirroredOneGraph(self):
-    if context.num_gpus() < 1 and context.executing_eagerly():
-      # Graph mode can work without GPU because the Placer "moves" the
-      # variable to a CPU. In other words, if there is no GPU available, but
-      # user requested to create a variable on GPU, Placer will ignore the
-      # user request and assign the VarHandleOp to CPU. This requires
-      # soft_placement, which is on by default.
-      self.skipTest("A GPU is not available for this test in eager mode.")
-
-    with self.cached_session(config=self.config) as sess:
-      mirrored = _make_mirrored()
-      v = mirrored.values
-
-      # Overwrite the initial values.
-      self._assign_mirrored(mirrored, [3., 4.])
-
-      # Saves the current value of v[0], 3.
-      save_path, saver = self._save_return_saver(sess, mirrored)
-
-      # Change the values between save and restore.
-      self._assign_mirrored(mirrored, [5., 6.])
-
-      # Restores the saved value of 3. to both variables.
-      saver.restore(sess, save_path)
-      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
-
-  def _save_mirrored(self):
+  def _save_mirrored(self, distribution):
     """Save variables with mirroring, returns save_path."""
     with self.session(graph=ops.Graph()) as sess:
-      mirrored = _make_mirrored()
+      mirrored = _make_mirrored(distribution)
 
       # Overwrite the initial values.
       self._assign_mirrored(mirrored, [3., 4.])
@@ -986,10 +974,10 @@
       saver.restore(sess, save_path)
       self.assertEqual(3., self.evaluate(var))
 
-  def _restore_mirrored(self, save_path):
+  def _restore_mirrored(self, save_path, distribution):
     """Restore to variables with mirroring in a fresh graph."""
     with self.session(graph=ops.Graph()) as sess:
-      mirrored = _make_mirrored()
+      mirrored = _make_mirrored(distribution)
       v = mirrored.values
 
       # Overwrite the initial values.
@@ -1000,8 +988,27 @@
       saver.restore(sess, save_path)
       self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
 
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testSaveMirroredRestoreMirrored(self):
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testSaveAndRestoreMirroredOneGraph(self, distribution):
+    with self.cached_session() as sess:
+      mirrored = _make_mirrored(distribution)
+      v = mirrored.values
+
+      # Overwrite the initial values.
+      self._assign_mirrored(mirrored, [3., 4.])
+
+      # Saves the current value of v[0], 3.
+      save_path, saver = self._save_return_saver(sess, mirrored)
+
+      # Change the values between save and restore.
+      self._assign_mirrored(mirrored, [5., 6.])
+
+      # Restores the saved value of 3. to both variables.
+      saver.restore(sess, save_path)
+      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))
+
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testSaveMirroredRestoreMirrored(self, distribution):
     if context.num_gpus() < 1 and context.executing_eagerly():
       # Graph mode can work without GPU because the Placer "moves" the
       # variable to a CPU. In other words, if there is no GPU available, but
@@ -1010,11 +1017,11 @@
       # soft_placement, which is on by default.
       self.skipTest("A GPU is not available for this test in eager mode.")
 
-    save_path = self._save_mirrored()
-    self._restore_mirrored(save_path)
+    save_path = self._save_mirrored(distribution)
+    self._restore_mirrored(save_path, distribution)
 
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testSaveMirroredRestoreNormal(self):
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testSaveMirroredRestoreNormal(self, distribution):
     if context.num_gpus() < 1 and context.executing_eagerly():
       # Graph mode can work without GPU because the Placer "moves" the
       # variable to a CPU. In other words, if there is no GPU available, but
@@ -1023,11 +1030,11 @@
       # soft_placement, which is on by default.
       self.skipTest("A GPU is not available for this test in eager mode.")
 
-    save_path = self._save_mirrored()
+    save_path = self._save_mirrored(distribution)
     self._restore_normal(save_path)
 
-  @test_util.run_in_graph_and_eager_modes(config=config)
-  def testSaveNormalRestoreMirrored(self):
+  @combinations.generate(mirrored_and_tpu_strategy_combinations())
+  def testSaveNormalRestoreMirrored(self, distribution):
     if context.num_gpus() < 1 and context.executing_eagerly():
       # Graph mode can work without GPU because the Placer "moves" the
       # variable to a CPU. In other words, if there is no GPU available, but
@@ -1037,7 +1044,7 @@
       self.skipTest("A GPU is not available for this test in eager mode.")
 
     save_path = self._save_normal()
-    self._restore_mirrored(save_path)
+    self._restore_mirrored(save_path, distribution)
 
 
 _TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
diff --git a/tensorflow/python/distribute/values_util.py b/tensorflow/python/distribute/values_util.py
index 535351e..c9bcf8f 100644
--- a/tensorflow/python/distribute/values_util.py
+++ b/tensorflow/python/distribute/values_util.py
@@ -28,6 +28,72 @@
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.saved_model import save_context
 from tensorflow.python.saved_model import save_options
+from tensorflow.python.training.saving import saveable_object
+
+
+def get_on_write_saveable(var, primary_var, name):
+  """Return saveable spec for AUTO and ON_WRITE variables."""
+  # We use a callable so that we don't have to evaluate this expression
+  # in the case where we are trying to restore instead of save.
+  def tensor():
+    strategy = var.distribute_strategy
+    return strategy.extended.read_var(var)
+
+  spec = saveable_object.SaveSpec(
+      tensor=tensor,
+      slice_spec="",
+      name=name,
+      dtype=var.dtype,
+      device=primary_var.device)
+
+  return tensor, [spec]
+
+
+def get_on_write_restore_ops(var, tensor):
+  """Return restore ops for AUTO and ON_WRITE variables."""
+  packed_var = var._packed_variable  # pylint: disable=protected-access
+  if packed_var is not None:
+    return control_flow_ops.group(
+        tuple(
+            assign_on_device(d, packed_var, tensor)
+            for d in packed_var.devices))
+  return control_flow_ops.group(
+      tuple(
+          assign_on_device(v.device, v, tensor)
+          for v in var.values))
+
+
+def get_on_read_saveable(var, primary_var, name):
+  """Return saveables for ON_READ variable."""
+
+  # We use a callable so that we don't have to evaluate this expression
+  # in the case where we are trying to restore instead of save.
+  def tensor():
+    return var._get_cross_replica()  # pylint: disable=protected-access
+
+  spec = saveable_object.SaveSpec(
+      tensor=tensor,
+      slice_spec="",
+      name=name,
+      dtype=var.dtype,
+      device=primary_var.device)
+
+  return tensor, [spec]
+
+
+def get_on_read_restore_ops(var, tensor, aggregation):
+  """Return restore ops for ON_READ variables."""
+  # To preserve the sum across save and restore, we have to divide the
+  # total across all devices when restoring a variable that was summed
+  # when saving.
+  if aggregation == vs.VariableAggregation.SUM:
+    strategy = var.distribute_strategy
+    tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync,
+                           var.dtype)
+  return control_flow_ops.group(
+      tuple(
+          assign_on_device(v.device, v, tensor)
+          for v in var.values))
 
 
 # Utility function that indicates if you are in an UpdateContext when running
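
The SUM branch of `get_on_read_restore_ops` preserves the invariant that the
cross-replica total survives a save/restore round trip; the arithmetic, as a
plain-Python sketch with assumed replica values:

# Saving an ON_READ SUM variable writes the total across replicas; restoring
# divides that total back out so the re-aggregated sum matches what was saved.
num_replicas_in_sync = 2
per_replica_values = [1.0, 3.0]
saved = sum(per_replica_values)                      # 4.0 in the checkpoint.
restored_per_replica = saved / num_replicas_in_sync  # 2.0 on every replica.
assert restored_per_replica * num_replicas_in_sync == saved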
diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py
index a8605a3..ba77384 100644
--- a/tensorflow/python/distribute/vars_test.py
+++ b/tensorflow/python/distribute/vars_test.py
@@ -20,10 +20,11 @@
 
 import itertools
 
+import uuid
 from absl.testing import parameterized
 
 from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute import values
@@ -41,6 +42,8 @@
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.tpu import tpu_strategy_util
+from tensorflow.python.training import checkpoint_management as ckpt_manager
+from tensorflow.python.training.tracking import util as trackable_utils
 
 
 _TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
@@ -78,22 +81,6 @@
 
 class OnWriteVariableSync(test.TestCase, parameterized.TestCase):
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=[
-              strategy_combinations.mirrored_strategy_with_one_gpu,
-          ],
-          mode=["graph"]))
-  def testFetchAMirroredVariable(self, distribution):
-    with self.session(graph=ops.Graph()) as sess, distribution.scope():
-      with ops.device("/device:GPU:0"):
-        v = variable_scope.get_variable(
-            name="v", initializer=1., use_resource=True)
-      mirrored = values.MirroredVariable(
-          distribution, (v,), variable_scope.VariableAggregation.MEAN)
-      sess.run(variables_lib.global_variables_initializer())
-      sess.run({"complicated": mirrored})
-
   @combinations.generate(strategy_and_run_tf_function_combinations())
   def testAssign(self, distribution, experimental_run_tf_function):
 
@@ -330,7 +317,7 @@
 
     @def_function.function
     def assign():
-      ctx = distribution_strategy_context.get_replica_context()
+      ctx = ds_context.get_replica_context()
       return v.assign(ctx.replica_id_in_sync_group)
 
     # disallow assign() with distributed value in replica context.
@@ -402,7 +389,7 @@
 
     @def_function.function
     def assign():
-      ctx = distribution_strategy_context.get_replica_context()
+      ctx = ds_context.get_replica_context()
       replica_id = ctx.replica_id_in_sync_group
       return v.assign(math_ops.cast(replica_id, dtypes.float32))
     per_replica_results = self.evaluate(distribution.experimental_local_results(
@@ -458,6 +445,60 @@
         distribution.experimental_local_results(distribution.run(add)))
     self.assertAllEqual([2, 2], per_replica_results)
 
+  @combinations.generate(
+      combinations.combine(
+          strategy=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          ],
+          mode=["eager"],
+          use_var_policy=[True, False]))
+  def testSaveAndRestoreOnWrite(self, strategy):
+    aggregation = [
+        variable_scope.VariableAggregation.NONE,
+        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA,
+        variable_scope.VariableAggregation.SUM,
+        variable_scope.VariableAggregation.MEAN
+    ]
+    for agg in aggregation:
+      v_normal_restore = variables_lib.Variable(1.0)
+      v_normal_save = variables_lib.Variable(3.0)
+      with strategy.scope():
+        v_on_write = variables_lib.Variable(2.0, aggregation=agg)
+
+        # Save ON_WRITE, restore ON_WRITE.
+        # Save
+        ckpt = trackable_utils.Checkpoint(var=v_on_write)
+        manager = ckpt_manager.CheckpointManager(
+            ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
+        manager.save()
+        # Restore
+        ckpt.restore(manager.latest_checkpoint)
+        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
+        self.assertEqual(2.0, self.evaluate(v_on_write.read_value()))
+
+        # Save ON_WRITE, restore normal.
+        # The ON_WRITE value is already saved; only the restore is needed.
+        ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
+        ckpt_normal.restore(manager.latest_checkpoint)
+        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
+        self.assertEqual(2.0, self.evaluate(v_normal_restore.read_value()))
+
+        # Save normal, restore ON_WRITE.
+        # Save
+        ckpt = trackable_utils.Checkpoint(var=v_normal_save)
+        manager_2 = ckpt_manager.CheckpointManager(
+            ckpt, "/tmp/ckptckpt_" + str(uuid.uuid4()), max_to_keep=None)
+        manager_2.save()
+        # Restore
+        ckpt_on_write = trackable_utils.Checkpoint(var=v_on_write)
+        ckpt_on_write.restore(manager_2.latest_checkpoint)
+        self.assertEqual(3.0, self.evaluate(v_on_write._values[0]))
+        self.assertEqual(3.0, self.evaluate(v_on_write.read_value()))
+
 
 @combinations.generate(
     combinations.combine(
@@ -468,7 +509,7 @@
         use_var_policy=[True, False]))
 class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):
 
-  def testScatterSub(self, distribution, use_var_policy):
+  def testScatterSub(self, distribution):
     with distribution.scope():
       v = variables_lib.Variable(
           [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
@@ -476,7 +517,7 @@
 
     @def_function.function
     def scatter_sub():
-      ctx = distribution_strategy_context.get_replica_context()
+      ctx = ds_context.get_replica_context()
       replica_id = ctx.replica_id_in_sync_group
       value = indexed_slices.IndexedSlices(
           values=array_ops.stack([
@@ -492,7 +533,7 @@
             distribution.run(scatter_sub)))
     self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)
 
-  def testScatterAdd(self, distribution, use_var_policy):
+  def testScatterAdd(self, distribution):
     with distribution.scope():
       v = variables_lib.Variable(
           [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
@@ -500,7 +541,7 @@
 
     @def_function.function
     def scatter_add():
-      ctx = distribution_strategy_context.get_replica_context()
+      ctx = ds_context.get_replica_context()
       replica_id = ctx.replica_id_in_sync_group
       value = indexed_slices.IndexedSlices(
           values=array_ops.stack([replica_id, replica_id + 1]),
@@ -513,7 +554,7 @@
             distribution.run(scatter_add)))
     self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)
 
-  def testScatterDiv(self, distribution, use_var_policy):
+  def testScatterDiv(self, distribution):
     with distribution.scope():
       v = variables_lib.Variable(
           [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
@@ -521,7 +562,7 @@
 
     @def_function.function
     def scatter_div():
-      ctx = distribution_strategy_context.get_replica_context()
+      ctx = ds_context.get_replica_context()
       replica_id = ctx.replica_id_in_sync_group
       value = indexed_slices.IndexedSlices(
           values=array_ops.reshape(replica_id + 2, [1]),
@@ -534,7 +575,7 @@
             distribution.run(scatter_div)))
     self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)
 
-  def testScatterMul(self, distribution, use_var_policy):
+  def testScatterMul(self, distribution):
     with distribution.scope():
       v = variables_lib.Variable(
           [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
@@ -542,7 +583,7 @@
 
     @def_function.function
     def scatter_mul():
-      ctx = distribution_strategy_context.get_replica_context()
+      ctx = ds_context.get_replica_context()
       replica_id = ctx.replica_id_in_sync_group
       value = indexed_slices.IndexedSlices(
           values=array_ops.reshape(
@@ -556,7 +597,7 @@
             distribution.run(scatter_mul)))
     self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)
 
-  def testScatterMin(self, distribution, use_var_policy):
+  def testScatterMin(self, distribution):
     with distribution.scope():
       v1 = variables_lib.Variable(
           [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
@@ -583,7 +624,7 @@
             distribution.run(scatter_min, args=(v2,))))
     self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)
 
-  def testScatterMax(self, distribution, use_var_policy):
+  def testScatterMax(self, distribution):
     with distribution.scope():
       v1 = variables_lib.Variable(
           [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
@@ -610,7 +651,7 @@
             distribution.run(scatter_max, args=(v2,))))
     self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)
 
-  def testScatterUpdate(self, distribution, use_var_policy):
+  def testScatterUpdate(self, distribution):
     with distribution.scope():
       v1 = variables_lib.Variable(
           [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
@@ -637,7 +678,7 @@
             distribution.run(scatter_update, args=(v2,))))
     self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)
 
-  def testScatterOpsInCrossReplicaContext(self, distribution, use_var_policy):
+  def testScatterOpsInCrossReplicaContext(self, distribution):
     with distribution.scope():
       v1 = variables_lib.Variable(
           [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
@@ -659,8 +700,7 @@
 class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):
 
   @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssign(self, distribution, experimental_run_tf_function,
-                 use_var_policy):
+  def testAssign(self, distribution, experimental_run_tf_function):
 
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
@@ -702,8 +742,7 @@
                             self.evaluate(array_ops.ones_like(component)))
 
   @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignOnReadVar(self, distribution, experimental_run_tf_function,
-                          use_var_policy):
+  def testAssignOnReadVar(self, distribution, experimental_run_tf_function):
 
     with distribution.scope():
       v_to_assign = variable_scope.variable(
@@ -764,8 +803,7 @@
                               self.evaluate(component.read_value()))
 
   @combinations.generate(strategy_and_run_tf_function_combinations())
-  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function,
-                              use_var_policy):
+  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):
 
     if isinstance(distribution, _TPU_STRATEGIES):
       self.skipTest("Assigning PerReplica values is not supported. See"
@@ -822,8 +860,7 @@
 
   @combinations.generate(strategy_and_run_tf_function_combinations())
   def testAssignDtypeConversion(self, distribution,
-                                experimental_run_tf_function,
-                                use_var_policy):
+                                experimental_run_tf_function):
 
     def assign(fn, v, update_value, cross_replica):
       update_fn = lambda: getattr(v, fn)(update_value)
@@ -865,7 +902,7 @@
                             self.evaluate(array_ops.ones_like(component)))
 
   @combinations.generate(strategy_with_var_policy())
-  def testAssignWithAggregationSum(self, distribution, use_var_policy):
+  def testAssignWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
           0.,
@@ -878,7 +915,7 @@
                           self.evaluate(array_ops.ones_like(component)))
 
   @combinations.generate(strategy_with_var_policy())
-  def testAssignAddSubWithAggregationSum(self, distribution, use_var_policy):
+  def testAssignAddSubWithAggregationSum(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
           0.,
@@ -894,8 +931,7 @@
 
   @combinations.generate(strategy_and_run_tf_function_combinations())
   def testReadValueInReplicaContext(self, distribution,
-                                    experimental_run_tf_function,
-                                    use_var_policy):
+                                    experimental_run_tf_function):
     aggregations = [
         variables_lib.VariableAggregation.NONE,
         variables_lib.VariableAggregation.SUM,
@@ -921,8 +957,7 @@
 
   @combinations.generate(strategy_and_run_tf_function_combinations())
   def testReadValueInCrossReplicaContext(self, distribution,
-                                         experimental_run_tf_function,
-                                         use_var_policy):
+                                         experimental_run_tf_function):
     aggregations = [
         variables_lib.VariableAggregation.SUM,
         variables_lib.VariableAggregation.MEAN,
@@ -940,7 +975,7 @@
       self.evaluate(variables_lib.global_variables_initializer())
 
       def assign(v=v):
-        ctx = distribution_strategy_context.get_replica_context()
+        ctx = ds_context.get_replica_context()
         replica_id = ctx.replica_id_in_sync_group
         return v.assign(math_ops.cast(replica_id, dtypes.float32))
 
@@ -967,8 +1002,7 @@
   # respected on GPUs.
   @combinations.generate(strategy_and_run_tf_function_combinations())
   def disable_testAllReduce(self, distribution,
-                            experimental_run_tf_function,
-                            use_var_policy):
+                            experimental_run_tf_function):
     with distribution.scope():
       v = variable_scope.variable(
           2.,
@@ -977,7 +1011,7 @@
     self.evaluate(variables_lib.global_variables_initializer())
 
     def all_reduce():
-      ctx = distribution_strategy_context.get_replica_context()
+      ctx = ds_context.get_replica_context()
       replica_id = ctx.replica_id_in_sync_group
       return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
                                                       dtypes.float32)
@@ -995,8 +1029,7 @@
 
   @combinations.generate(strategy_and_run_tf_function_combinations())
   def testAssignPerReplicaBeforeRead(self, distribution,
-                                     experimental_run_tf_function,
-                                     use_var_policy):
+                                     experimental_run_tf_function):
     aggregations = [
         variables_lib.VariableAggregation.SUM,
         variables_lib.VariableAggregation.MEAN,
@@ -1011,7 +1044,7 @@
       self.evaluate(variables_lib.global_variables_initializer())
 
       def assign(var=v):
-        ctx = distribution_strategy_context.get_replica_context()
+        ctx = ds_context.get_replica_context()
         replica_id = ctx.replica_id_in_sync_group
         return var.assign(math_ops.cast(replica_id, dtypes.float32))
 
@@ -1026,8 +1059,7 @@
       self.assertEqual(per_replica_results, tuple(expected_result))
 
   @combinations.generate(strategy_with_var_policy())
-  def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution,
-                                                            use_var_policy):
+  def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
     with distribution.scope():
       v = variable_scope.variable(
           0.,
@@ -1039,8 +1071,7 @@
       self.evaluate(v.read_value())
 
   @combinations.generate(strategy_with_var_policy())
-  def testInitializedToSameValueInsideEagerRun(self, distribution,
-                                               use_var_policy):
+  def testInitializedToSameValueInsideEagerRun(self, distribution):
     if not context.executing_eagerly(): self.skipTest("eager only")
 
     v = [None]
@@ -1060,7 +1091,7 @@
     self.assertAllEqual(vals[0], vals[1])
 
   @combinations.generate(strategy_with_var_policy())
-  def testOperatorOverride(self, distribution, use_var_policy):
+  def testOperatorOverride(self, distribution):
 
     with distribution.scope():
       v = variable_scope.variable(
@@ -1071,7 +1102,7 @@
 
       @def_function.function
       def assign():
-        ctx = distribution_strategy_context.get_replica_context()
+        ctx = ds_context.get_replica_context()
         replica_id = ctx.replica_id_in_sync_group
         return v.assign(math_ops.cast(replica_id, dtypes.float32))
 
@@ -1088,6 +1119,73 @@
           distribution.experimental_local_results(distribution.run(add)))
       self.assertAllEqual([1, 2], per_replica_results)
 
+  @combinations.generate(
+      combinations.combine(
+          strategy=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.tpu_strategy_packed_var,
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          ],
+          mode=["eager"],
+          use_var_policy=[True, False]))
+  def testSaveAndRestoreOnRead(self, strategy):
+    aggregation = [variable_scope.VariableAggregation.SUM,
+                   variable_scope.VariableAggregation.MEAN]
+    for agg in aggregation:
+      v_normal_restore = variables_lib.Variable(1.0)
+      v_normal_save = variables_lib.Variable(2.0)
+
+      with strategy.scope():
+        v_on_read = variables_lib.Variable(
+            1.0, synchronization=variable_scope.VariableSynchronization.ON_READ,
+            aggregation=agg)
+
+        @def_function.function
+        def assign_fn():
+          cluster_resolver = strategy.cluster_resolver
+          replica_ctx = ds_context.get_replica_context()
+          if ((cluster_resolver and cluster_resolver.task_type == "worker") or
+              math_ops.equal(replica_ctx.replica_id_in_sync_group,
+                             constant_op.constant(1))):
+            v_on_read.assign(3.)  # pylint:disable=cell-var-from-loop
+          else:
+            v_on_read.assign(4.)  # pylint:disable=cell-var-from-loop
+
+        strategy.run(assign_fn)
+
+        # Save ONREAD, restore ONREAD
+        # Saves v[0] + v[1] = 7 for SUM and 3.5 for MEAN.
+        ckpt = trackable_utils.Checkpoint(var=v_on_read)
+        manager = ckpt_manager.CheckpointManager(
+            ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
+        manager.save()
+        # Restores a value of 7/2 = 3.5 for SUM and 3.5 for MEAN.
+        ckpt.restore(manager.latest_checkpoint)
+        self.assertEqual(3.5, self.evaluate(v_on_read._values[0]))
+
+        # Save ONREAD, restore normal
+        ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
+        ckpt_normal.restore(manager.latest_checkpoint)
+        if agg == variable_scope.VariableAggregation.SUM:
+          self.assertEqual(7.0, self.evaluate(v_normal_restore.read_value()))
+        else:
+          self.assertEqual(3.5, self.evaluate(v_normal_restore.read_value()))
+
+        # Save normal, restore ONREAD
+        ckpt = trackable_utils.Checkpoint(var=v_normal_save)
+        manager = ckpt_manager.CheckpointManager(
+            ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
+        manager.save()
+        # Restores a value of 2/2 = 1.0 for SUM and 2.0 for MEAN.
+        ckpt_on_read = trackable_utils.Checkpoint(var=v_on_read)
+        ckpt_on_read.restore(manager.latest_checkpoint)
+        if agg == variable_scope.VariableAggregation.SUM:
+          self.assertEqual(1.0, self.evaluate(v_on_read._values[0]))
+        else:
+          self.assertEqual(2.0, self.evaluate(v_on_read._values[0]))
+
 
 @combinations.generate(
     combinations.combine(
@@ -1103,7 +1201,7 @@
         use_var_policy=[True, False]))
 class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
 
-  def testScatterSub(self, distribution, aggregation, use_var_policy):
+  def testScatterSub(self, distribution, aggregation):
     with distribution.scope():
       v = variables_lib.Variable(
           [1., 1., 1.],
@@ -1121,7 +1219,7 @@
     with self.assertRaises(NotImplementedError):
       self.evaluate(distribution.run(v.scatter_sub, args=(delta,)))
 
-  def testScatterAdd(self, distribution, aggregation, use_var_policy):
+  def testScatterAdd(self, distribution, aggregation):
     with distribution.scope():
       v = variables_lib.Variable(
           [1., 1., 1.],
@@ -1139,7 +1237,7 @@
     with self.assertRaises(NotImplementedError):
       self.evaluate(distribution.run(v.scatter_add, args=(delta,)))
 
-  def testScatterDiv(self, distribution, aggregation, use_var_policy):
+  def testScatterDiv(self, distribution, aggregation):
     with distribution.scope():
       v = variables_lib.Variable(
           [2., 6., 1.],
@@ -1157,7 +1255,7 @@
     with self.assertRaises(NotImplementedError):
       self.evaluate(distribution.run(v.scatter_div, args=(delta,)))
 
-  def testScatterMul(self, distribution, aggregation, use_var_policy):
+  def testScatterMul(self, distribution, aggregation):
     with distribution.scope():
       v = variables_lib.Variable(
           [2., 1., 1.],
@@ -1175,7 +1273,7 @@
     with self.assertRaises(NotImplementedError):
       self.evaluate(distribution.run(v.scatter_mul, args=(delta,)))
 
-  def testScatterMin(self, distribution, aggregation, use_var_policy):
+  def testScatterMin(self, distribution, aggregation):
     with distribution.scope():
       v = variables_lib.Variable(
           [3., 4., 5.],
@@ -1193,7 +1291,7 @@
     with self.assertRaises(NotImplementedError):
       self.evaluate(distribution.run(v.scatter_min, args=(delta,)))
 
-  def testScatterMax(self, distribution, aggregation, use_var_policy):
+  def testScatterMax(self, distribution, aggregation):
     with distribution.scope():
       v = variables_lib.Variable(
           [3., 4., 5.],
@@ -1211,7 +1309,7 @@
     with self.assertRaises(NotImplementedError):
       self.evaluate(distribution.run(v.scatter_max, args=(delta,)))
 
-  def testScatterUpdate(self, distribution, aggregation, use_var_policy):
+  def testScatterUpdate(self, distribution, aggregation):
     with distribution.scope():
       v = variables_lib.Variable(
           [0., 0., 0.],
@@ -1231,4 +1329,4 @@
 
 
 if __name__ == "__main__":
-  test.main()
+  combinations.main()
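
For context on the arithmetic asserted in testSaveAndRestoreOnRead above: saving an ON_READ variable writes the aggregated value, and restoring divides it back across replicas for SUM aggregation. A minimal sketch with public APIs, assuming two local devices (the device list and checkpoint path are illustrative, not part of this patch):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(
      3.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

ckpt = tf.train.Checkpoint(var=v)
path = ckpt.save("/tmp/onread_demo")  # saves the aggregated value: 3 + 3 = 6
ckpt.restore(path)                    # each replica receives 6 / 2 = 3.0
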
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index f087903..2d1b701 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -5,6 +5,9 @@
 
 # buildifier: disable=same-origin-load
 load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "pybind_extension")
 load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
 load(
     "//tensorflow/tools/test:performance.bzl",
@@ -580,6 +583,16 @@
     ],
 )
 
+pybind_extension(
+    name = "_concrete_function",
+    srcs = ["function.cc"],
+    module_name = "_concrete_function",
+    deps = [
+        "//third_party/python_runtime:headers",
+        "@pybind11",
+    ],
+)
+
 py_library(
     name = "backprop",
     srcs = ["backprop.py"],
diff --git a/tensorflow/python/eager/benchmarks/BUILD b/tensorflow/python/eager/benchmarks/BUILD
new file mode 100644
index 0000000..8e147d5
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/BUILD
@@ -0,0 +1,21 @@
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cuda_py_test(
+    name = "kpi_benchmark_test",
+    size = "medium",
+    srcs = ["kpi_benchmark_test.py"],
+    python_version = "PY3",
+    tags = [
+        "no_windows",  #  b/141617449
+        "optonly",
+    ],
+    deps = [
+        "//tensorflow:tensorflow_py_no_contrib",
+        "//tensorflow/python/eager:benchmarks_test_base",
+    ],
+)
diff --git a/tensorflow/python/eager/benchmarks/kpi_benchmark_test.py b/tensorflow/python/eager/benchmarks/kpi_benchmark_test.py
new file mode 100644
index 0000000..22a70e1
--- /dev/null
+++ b/tensorflow/python/eager/benchmarks/kpi_benchmark_test.py
@@ -0,0 +1,121 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""KPI Benchmarks for low-level eager execution primitives.
+
+This is a suite of full end-to-end integration benchmarks for low-level eager
+execution APIs. It also tracks them as KPI TraceMe events.
+
+To run CPU benchmarks:
+  bazel run -c opt kpi_benchmark_test -- --benchmarks=.
+
+To run GPU benchmarks:
+  bazel run --config=cuda -c opt --copt="-mavx" kpi_benchmark_test -- \
+    --benchmarks=.
+
+To run a subset of benchmarks, use the --benchmarks flag:
+--benchmarks: the list of benchmarks to run. The specified value is interpreted
+as a regular expression, and any benchmark whose name contains a partial match
+to the regular expression is executed.
+e.g. --benchmarks=".*matmul*." will run all matmul-related benchmarks.
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import time
+
+import tensorflow as tf
+
+from tensorflow.python.eager import benchmarks_test_base
+from tensorflow.python.eager import context
+from tensorflow.python.profiler import trace
+
+NUM_ITERATIONS = 30000
+
+
+def _run_benchmark(func, num_iters, execution_mode=None):
+  ctx = context.context()
+  with context.execution_mode(execution_mode):
+    # call func to warm up
+    func()
+    if execution_mode == context.ASYNC:
+      ctx.executor.wait()
+    start = time.time()
+    for _ in range(num_iters):
+      func()
+    if execution_mode == context.ASYNC:
+      ctx.executor.wait()
+    end = time.time()
+
+    return end - start
+
+
+class KpiBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
+  """A Collection of KPI benchmarks."""
+
+  def _get_benchmark_name(self):
+    return self._get_name()
+
+  def _run(self, func, num_iters):
+    gc.disable()
+    gc.collect()
+    self.run_report(_run_benchmark, func, num_iters)
+    gc.enable()
+
+  def benchmark_tf_constant_2x2(self):
+    x = [[1., 2.], [3., 4.]]
+
+    def fn():
+      with trace.Trace("tf.constant-2x2"):
+        tf.constant(x)
+
+    self._run(fn, NUM_ITERATIONS)
+
+  def benchmark_tf_convert_to_tensor_2x2(self):
+    x = [[1., 2.], [3., 4.]]
+
+    def fn():
+      with trace.Trace("tf.convert_to_tensor-2x2"):
+        tf.convert_to_tensor(x)
+
+    self._run(fn, NUM_ITERATIONS)
+
+  def benchmark_tf_nn_relu_2x2(self):
+    x = tf.constant([[1., 2.], [3., 4.]])
+
+    def fn():
+      with trace.Trace("tf.nn.relu-2x2"):
+        tf.nn.relu(x)
+
+    self._run(fn, NUM_ITERATIONS)
+
+  def benchmark_tf_function_invocation_identity(self):
+    x = tf.constant([[1., 2.], [3., 4.]])
+
+    @tf.function
+    def identity(x):
+      return x
+
+    def fn():
+      with trace.Trace("tf.function-identity"):
+        identity(x)
+
+    self._run(fn, NUM_ITERATIONS)
+
+
+if __name__ == "__main__":
+  tf.test.main()
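
Every benchmark in this new file follows the same shape: build the inputs once, wrap the op under test in a trace.Trace scope, and pass the closure to self._run. A sketch of adding one more entry to KpiBenchmarks (the op and trace name are illustrative, not part of this patch):

def benchmark_tf_reduce_sum_2x2(self):
  x = tf.constant([[1., 2.], [3., 4.]])

  def fn():
    with trace.Trace("tf.reduce_sum-2x2"):
      tf.reduce_sum(x)

  self._run(fn, NUM_ITERATIONS)
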
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 6150ca1..fd50f78 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -253,32 +253,26 @@
     tensor_b = constant_op.constant([[24, 24], [24, 24]])
     self._benchmark_add(tensor_a, tensor_b)
 
-  @test_util.disable_tfrt("convert_to_tensor not handled")
   def benchmark_create_float_tensor_from_list_CPU(self):
     self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)
 
-  @test_util.disable_tfrt("convert_to_tensor not handled")
   def benchmark_create_float_tensor_from_np_array_CPU(self):
     self._benchmark_create_tensor(
         np.array([[3.0]], dtype=np.float32), dtypes.float32.as_datatype_enum,
         CPU)
 
-  @test_util.disable_tfrt("convert_to_tensor not handled")
   def benchmark_create_int32_tensor_from_list_CPU(self):
     self._benchmark_create_tensor([[3]], dtypes.int32.as_datatype_enum, CPU)
 
-  @test_util.disable_tfrt("convert_to_tensor not handled")
   def benchmark_create_int32_tensor_from_np_array_CPU(self):
     self._benchmark_create_tensor(
         np.array([[3]], dtype=np.int32), dtypes.int32.as_datatype_enum, CPU)
 
-  @test_util.disable_tfrt("no gpu support")
   def benchmark_create_float_tensor_from_list_GPU(self):
     if not context.num_gpus():
       return
     self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, GPU)
 
-  @test_util.disable_tfrt("no gpu support")
   def benchmark_create_float_tensor_from_np_array_GPU(self):
     if not context.num_gpus():
       return
@@ -286,14 +280,12 @@
         np.array([[3.0]], dtype=np.float32), dtypes.float32.as_datatype_enum,
         GPU)
 
-  @test_util.disable_tfrt("no gpu support")
   def benchmark_create_int32_tensor_from_list_GPU(self):
     # int32's are kept on host memory even when executing on GPU.
     if not context.num_gpus():
       return
     self._benchmark_create_tensor([[3]], dtypes.int32.as_datatype_enum, GPU)
 
-  @test_util.disable_tfrt("no gpu support")
   def benchmark_create_int32_tensor_from_np_array_GPU(self):
     # int32's are kept on host memory even when executing on GPU.
     if not context.num_gpus():
@@ -301,17 +293,14 @@
     self._benchmark_create_tensor(
         np.array([[3]], dtype=np.int32), dtypes.int32.as_datatype_enum, GPU)
 
-  @test_util.disable_tfrt("strided slice not supported")
   def benchmark_index_tensor_with_literal(self):
     func = lambda: constant_op.constant([3.0])[0]
     self._run(func, 30000)
 
-  @test_util.disable_tfrt("strided slice not supported")
   def benchmark_index_tensor_with_tensor(self):
     func = lambda idx=constant_op.constant(0): constant_op.constant([3.0])[idx]
     self._run(func, 30000)
 
-  @test_util.disable_tfrt("strided slice not supported")
   def benchmark_index_tensor_with_np_array(self):
     func = lambda idx=np.array(0): constant_op.constant([3.0])[idx]
     self._run(func, 30000)
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 9c939fe..567a34d 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -68,9 +68,6 @@
 SYNC = 0
 ASYNC = 1
 
-MIRRORING_NONE = pywrap_tfe.TFE_MIRRORING_NONE
-MIRRORING_ALL = pywrap_tfe.TFE_MIRRORING_ALL
-
 _KEEP_ALIVE_SECS = 600
 
 _python_eager_context_create_counter = monitoring.Counter(
@@ -772,6 +769,26 @@
     self.ensure_initialized()
     pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
 
+  def check_collective_ops_peer_health(self, task):
+    """Check collective peer health.
+
+    This probes the given task to see if it is still alive. Note that a
+    restarted task counts as a different task and is reported as not healthy.
+
+    This should only be used in multi-client, multi-worker training.
+
+    Args:
+      task: a task string in the format /job:xxx/replica:0/task:N.
+
+    Raises:
+      tf.errors.UnavailableError: when a peer is down.
+      tf.errors.FailedPreconditionError: when the peer is a different task from
+        the one this task previously talked to, e.g. because it restarted.
+      tf.errors.InvalidArgumentError: when the task string is invalid.
+    """
+    self.ensure_initialized()
+    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task)
+
   @property
   def _handle(self):
     if self._context_handle is None:
@@ -1649,27 +1666,6 @@
             self._handle, self._device_policy)
 
   @property
-  def mirroring_policy(self):
-    # Only get the policy from the context if it has already been initialized
-    if self._context_handle is not None:
-      return pywrap_tfe.TFE_ContextGetMirroringPolicy(self._handle)
-
-    return self._mirroring_policy
-
-  @mirroring_policy.setter
-  def mirroring_policy(self, policy):
-    if policy is None:
-      policy = MIRRORING_NONE
-
-    if self._mirroring_policy is None or self._mirroring_policy != policy:
-      self._mirroring_policy = policy
-
-      # Only set the policy if the context has already been initialized
-      if self._context_handle is not None:
-        pywrap_tfe.TFE_ContextSetThreadLocalMirroringPolicy(
-            self._handle, self._mirroring_policy)
-
-  @property
   def lazy_remote_inputs_copy(self):
     return self._lazy_remote_inputs_copy
 
@@ -2104,18 +2100,6 @@
     ctx.device_policy = old_policy
 
 
-@tf_contextlib.contextmanager
-def mirroring_policy(policy):
-  """Context manager for setting mirroring policy for current thread."""
-  ctx = context()
-  old_policy = ctx.mirroring_policy
-  try:
-    ctx.mirroring_policy = policy
-    yield
-  finally:
-    ctx.mirroring_policy = old_policy
-
-
 def set_execution_mode(mode):
   """Sets execution mode for the current thread."""
   context().execution_mode = mode
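
A hedged sketch of how the new check_collective_ops_peer_health method might be used from a multi-client training loop; the task string is illustrative:

from tensorflow.python.eager import context
from tensorflow.python.framework import errors

try:
  context.context().check_collective_ops_peer_health(
      "/job:worker/replica:0/task:1")
except errors.UnavailableError:
  pass  # The peer is down; trigger failover or restart logic here.
except errors.FailedPreconditionError:
  pass  # The peer restarted; its state must be treated as fresh.
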
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index d7f412b..a2840b3 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -431,7 +431,6 @@
     self.assertFalse(switch.is_building_function)
 
   @test_util.run_gpu_only
-  @test_util.disable_tfrt('Resolve not implemented yet.')
   def testInt32GPU(self):
     with ops.device('gpu:0'):
       xent = nn_ops.sparse_softmax_cross_entropy_with_logits(
@@ -485,7 +484,6 @@
       self.assertAllEqual(t.numpy(), 10.0)
 
   @test_util.run_gpu_only
-  @test_util.disable_tfrt('Resolve not implemented yet.')
   def testDevicePlacementEnforcesConsistency(self):
     cpu = context.device('cpu:0')
     gpu = context.device('gpu:0')
@@ -528,7 +526,6 @@
     self.assertEqual(3, result)
 
   @test_util.run_gpu_only
-  @test_util.disable_tfrt('Resolve not implemented yet.')
   def testResourceTensorPlacement(self):
     with context.device('gpu:0'):
       v = resource_variable_ops.ResourceVariable(1.0)
@@ -568,7 +565,7 @@
     context.context().executor.clear_error()
 
   @test_util.run_gpu_only
-  @test_util.disable_tfrt('TensorHandleInterface::Resolve() not implemented.')
+  @test_util.disable_tfrt('Device placement not implemented.')
   def testCopyScope(self):
     constant = constant_op.constant(1.0)
     with ops.device('gpu:0'):
@@ -609,7 +606,6 @@
     async_executor.wait()
 
   @test_util.run_gpu_only
-  @test_util.disable_tfrt('Resolve not implemented yet.')
   def testNumpyForceCPU(self):
     cpu = constant_op.constant([[1., 2.], [3., 4.]])
     c2g = cpu.gpu()
@@ -692,7 +688,7 @@
           attrs=('T', dtypes.int32.as_datatype_enum))[0]
 
   @test_util.run_gpu_only
-  @test_util.disable_tfrt('Resolve not implemented yet.')
+  @test_util.disable_tfrt('Device placement not implemented yet.')
   def testMatMulGPU(self):
     three = constant_op.constant([[3.]]).gpu()
     five = constant_op.constant([[5.]]).gpu()
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 22cd1ce..3199747 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -845,13 +845,14 @@
         # stateless function.
         return self._stateless_fn(*args, **kwds)
     else:
-      _, _, flat_args, flat_kwds = \
+      _, _, _, filtered_flat_args = \
           self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
               *args, **kwds)
       # If we did not create any variables the trace we have is good enough.
-      return self._concrete_stateful_fn._filtered_call(flat_args, flat_kwds)  # pylint: disable=protected-access
+      return self._concrete_stateful_fn._call_flat(
+          filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
 
-    def fn_with_cond(inner_args, inner_kwds, inner_flat_args, inner_flat_kwds):
+    def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
       """Conditionally runs initialization if it's needed."""
       condition = True
       for wr in self._created_variables:
@@ -900,17 +901,17 @@
           condition,
           lambda: self._stateless_fn(*inner_args, **inner_kwds),
           functools.partial(
-              self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
-              inner_flat_args,
-              inner_flat_kwds))
+              self._concrete_stateful_fn._call_flat,  # pylint: disable=protected-access
+              inner_filtered_flat_args,
+              captured_inputs=self._concrete_stateful_fn.captured_inputs))
 
     # We've created variables and are unable to lift the initialization graphs,
     # so we fall back to initializing with conds while running the function.
-    canon_args, canon_kwds, flat_args, flat_kwds = \
+    canon_args, canon_kwds, _, filtered_flat_args = \
         self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
             *args, **kwds)
-    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds, flat_args,
-                                            flat_kwds)
+    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
+                                            filtered_flat_args)
 
   @property
   def python_function(self):
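
The refactoring above concerns functions that create variables on their first call, where initialization either gets lifted out or falls back to the cond-guarded path. A minimal sketch of user code that exercises this machinery (whether lifting or the cond fallback is taken is an internal detail):

import tensorflow as tf

class Scale(tf.Module):

  def __init__(self):
    super().__init__()
    self.v = None

  @tf.function
  def __call__(self, x):
    if self.v is None:
      self.v = tf.Variable(2.0)  # created on the first trace only
    return self.v * x

print(Scale()(tf.constant(3.0)).numpy())  # 6.0
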
diff --git a/tensorflow/python/eager/function.cc b/tensorflow/python/eager/function.cc
new file mode 100644
index 0000000..0fc22d3
--- /dev/null
+++ b/tensorflow/python/eager/function.cc
@@ -0,0 +1,83 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <Python.h>
+
+#include "pybind11/pybind11.h"
+#include "pybind11/stl_bind.h"
+
+struct ConcreteFunction;  // Forward declaration.
+
+// TODO(jlchu): Migrate Python characteristics to C++
+
+namespace tensorflow {
+
+namespace py = pybind11;
+
+struct PyConcreteFunction {
+  PyConcreteFunction() {}
+  py::object _build_call_outputs(py::object result,
+                                 py::object structured_outputs,
+                                 bool _ndarrays_list, bool _ndarray_singleton);
+};
+
+py::object PyConcreteFunction::_build_call_outputs(
+    py::object result, py::object structured_outputs, bool _ndarrays_list,
+    bool _ndarray_singleton) {
+  static const py::module* nest =
+      new py::module(py::module::import("tensorflow.python.util.nest"));
+  // TODO(jlchu): Look into lazy loading of np_arrays module
+  static const py::module* np_arrays = new py::module(
+      py::module::import("tensorflow.python.ops.numpy_ops.np_arrays"));
+
+  if (structured_outputs.is_none()) {
+    return result;
+  }
+
+  // TODO(jlchu): Verify invariant: result is None only if
+  // structured_outputs is None?
+  py::list list_result = (py::list)result;
+
+  if (!list_result.empty()) {
+    if (_ndarrays_list) {
+      py::list ndarr_result(list_result.size());
+      for (size_t i = 0; i < ndarr_result.size(); ++i) {
+        ndarr_result[i] = np_arrays->attr("tensor_to_ndarray")(list_result[i]);
+      }
+      return ndarr_result;
+    } else if (_ndarray_singleton) {
+      return np_arrays->attr("tensor_to_ndarray")(list_result[0]);
+    }
+  }
+
+  // Replace outputs with results, skipping over any 'None' values.
+  py::list outputs_list = nest->attr("flatten")(structured_outputs, true);
+  int j = 0;
+  for (size_t i = 0; i < outputs_list.size(); ++i) {
+    if (!outputs_list[i].is_none()) {
+      outputs_list[i] = list_result[j];
+      ++j;
+    }
+  }
+  return nest->attr("pack_sequence_as")(structured_outputs, outputs_list, true);
+}
+
+PYBIND11_MODULE(_concrete_function, m) {
+  py::class_<PyConcreteFunction>(m, "ConcreteFunction")
+      .def(py::init<>())
+      .def("_build_call_outputs", &PyConcreteFunction::_build_call_outputs);
+}
+
+}  // namespace tensorflow
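
Once built by the pybind_extension rule added to tensorflow/python/eager/BUILD, the module could be exercised from Python roughly as below; this is a sketch, and _build_call_outputs is still experimental per the TODOs above:

from tensorflow.python.eager import _concrete_function

cf = _concrete_function.ConcreteFunction()
# With structured_outputs=None, the result is passed through unchanged.
outputs = cf._build_call_outputs([1, 2, 3], None, False, False)
assert outputs == [1, 2, 3]
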
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 46d7596..4a5b028 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -41,6 +41,7 @@
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import forwardprop_util
+from tensorflow.python.eager import monitoring
 from tensorflow.python.eager import tape
 from tensorflow.python.eager.graph_only_ops import graph_placeholder
 from tensorflow.python.framework import c_api_util
@@ -92,6 +93,10 @@
 IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
 SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"
 
+_graph_building_time_counter = monitoring.Counter(
+    "/tensorflow/core/tf_function/graph_building_time_usecs",
+    "Time for tf.function to build a graph (us).")
+
 
 def _make_input_signature_hashable(elem):
   """Rewrite input signature to be hashable.
@@ -1746,12 +1751,15 @@
       TypeError: if `args` and `kwargs` do not match the structured signature
         of this `ConcreteFunction`.
     """
-    args, kwargs, flat_args, flat_kwargs = \
+    args, kwargs, _, filtered_flat_args = \
         self._function_spec.canonicalize_function_inputs(*args, **kwargs)
     self._structured_signature_check_missing_args(args, kwargs)
     self._structured_signature_check_unexpected_args(args, kwargs)
     self._structured_signature_check_arg_types(args, kwargs)
-    return self._filtered_call(flat_args, flat_kwargs, cancellation_manager)
+    return self._call_flat(
+        filtered_flat_args,
+        captured_inputs=self.captured_inputs,
+        cancellation_manager=cancellation_manager)
 
   def _structured_signature_check_missing_args(self, args, kwargs):
     """Raises a TypeError if any args are missing."""
@@ -1833,38 +1841,14 @@
                             type(spec_piece).__name__, spec_piece, name,
                             type(arg_piece).__name__, arg_piece))
 
-  def _filtered_call(self, flat_args, flat_kwargs, cancellation_manager=None):
-    """Executes the function, filtering arguments from the Python function.
-
-    Objects aside from Tensors, CompositeTensors, and Variables are ignored.
-    CompositeTensors have been expanded into their components on input.
-
-    Args:
-      flat_args: Flattened canonicalized positional arguments of the Python
-        function.
-      flat_kwargs: Flattened canonicalized keyword arguments of the Python
-        function.
-      cancellation_manager: (Optional.) A `CancellationManager` that can be
-        used to cancel function invocation.
-
-    Returns:
-      The result of applying the function on the Tensors/Variables contained in
-      `flat_args` and `flat_kwargs`.
-    """
-    return self._call_flat([
-        t for t in flat_args + flat_kwargs
-        if isinstance(t, (ops.Tensor,
-                          resource_variable_ops.BaseResourceVariable))
-    ],
-                           captured_inputs=self.captured_inputs,
-                           cancellation_manager=cancellation_manager)
-
   def _call_flat(self, args, captured_inputs, cancellation_manager=None):
     """Executes the wrapped function.
 
     Args:
-      args: a list of Tensors or Variables.  Any CompositeTensors should be
-        expanded before calling this method.
+      args: a list of Tensors or Variables. Arguments from the Python function
+        should be filtered before calling this method: objects aside from
+        Tensors, CompositeTensors, and Variables are ignored. Any
+        CompositeTensors should be expanded before calling this method.
       captured_inputs: the captured inputs that are also part of the input args
         to the actual execution. By default, it should be self._captured_inputs.
       cancellation_manager: (Optional.) A `CancellationManager` that can be
@@ -2167,7 +2151,7 @@
     Returns:
       The actual call output.
     """
-    # TODO(jlchu): implement in C++.
+    # TODO(jlchu): call C++ version in function.cc when speed is improved
     if self._func_graph.structured_outputs is None:
       return result
 
@@ -2607,11 +2591,12 @@
       **kwargs: The keyword args this function was called with.
 
     Returns:
-      A canonicalized ordering of the inputs representened by a tuple in the
-      form (args, kwargs), followed by their flattened versions in the form
-      (flat_args, flat_kwargs). Here: `args` is a full list of bound arguments,
-      and `kwargs` contains only true keyword arguments, as opposed to named
-      arguments called in a keyword-like fashion.
+      A canonicalized ordering of the inputs, as well as full and filtered
+      (Tensors and Variables only) versions of their concatenated flattened
+      representations, represented by a tuple in the form (args, kwargs,
+      flat_args, filtered_flat_args). Here: `args` is a full list of bound
+      arguments, and `kwargs` contains only true keyword arguments, as opposed
+      to named arguments called in a keyword-like fashion.
 
     Raises:
       ValueError: If a keyword in `kwargs` cannot be matched with a positional
@@ -2691,14 +2676,15 @@
           kwargs.setdefault(kwarg, default)
 
     if self._input_signature is None:
-      inputs, flat_inputs = _convert_numpy_inputs(inputs)
-      kwargs, flat_kwargs = _convert_numpy_inputs(kwargs)
-      return inputs, kwargs, flat_inputs, flat_kwargs
+      inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs)
+      kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs)
+      return (inputs, kwargs, flat_inputs + flat_kwargs,
+              filtered_flat_inputs + filtered_flat_kwargs)
     else:
       assert not kwargs
-      inputs, flat_inputs = _convert_inputs_to_signature(
+      inputs, flat_inputs, filtered_flat_inputs = _convert_inputs_to_signature(
           inputs, self._input_signature, self._flat_input_signature)
-      return inputs, {}, flat_inputs, []
+      return inputs, {}, flat_inputs, filtered_flat_inputs
 
 
 def _as_ndarray(value):
@@ -2728,7 +2714,7 @@
   # We assume that any CompositeTensors have already converted their components
   # from numpy arrays to Tensors, so we don't need to expand composites here for
   # the numpy array conversion. Instead, we do so because the flattened inputs
-  # are eventually passed to ConcreteFunction()._filtered_call, which requires
+  # are eventually passed to ConcreteFunction()._call_flat, which requires
   # expanded composites.
   flat_inputs = nest.flatten(inputs, expand_composites=True)
 
@@ -2737,20 +2723,28 @@
   # finding a way to store them directly in the cache key (currently not
   # possible since ndarrays are not hashable).
   need_packing = False
+  filtered_flat_inputs = []
   for index, value in enumerate(flat_inputs):
-    if _is_ndarray(value):
+    if isinstance(value,
+                  (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
+      filtered_flat_inputs.append(value)
+    elif hasattr(value, "__array__") and not (
+        hasattr(value, "_should_act_as_resource_variable") or
+        isinstance(value, (np.str_, type, composite_tensor.CompositeTensor))):
+      # This case is equivalent to _is_ndarray(value) == True
       a = _as_ndarray(value)
       if not isinstance(a, np.ndarray):
         raise TypeError("The output of __array__ must be an np.ndarray "
                         "(got {} from {}).".format(type(a), type(value)))
       flat_inputs[index] = constant_op.constant(a)
+      filtered_flat_inputs.append(flat_inputs[index])
       need_packing = True
   if need_packing:
     return (nest.pack_sequence_as(
         structure=inputs, flat_sequence=flat_inputs,
-        expand_composites=True), flat_inputs)
+        expand_composites=True), flat_inputs, filtered_flat_inputs)
   else:
-    return inputs, flat_inputs
+    return inputs, flat_inputs, filtered_flat_inputs
 
 
 def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
@@ -2799,7 +2793,12 @@
         flat_sequence=flatten_inputs,
         expand_composites=True)
 
-  return inputs, nest.flatten(inputs, expand_composites=True)
+  flat_inputs = nest.flatten(inputs, expand_composites=True)
+
+  return (inputs, flat_inputs, [
+      t for t in flat_inputs
+      if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
+  ])
 
 
 class FunctionCache(object):
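
The several filtered_flat_args variants introduced in this file all apply the same predicate: keep only the values _call_flat can execute with. A sketch of that idea in isolation (the helper name is illustrative):

from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

def _filter_flat_args(flat_inputs):
  # Drop everything except Tensors and resource variables.
  return [
      t for t in flat_inputs
      if isinstance(t, (ops.Tensor,
                        resource_variable_ops.BaseResourceVariable))
  ]
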
@@ -2924,9 +2923,10 @@
   def __call__(self, *args, **kwargs):
     """Calls a graph function specialized to the inputs."""
     with self._lock:
-      graph_function, flat_args, flat_kwargs = \
-          self._maybe_define_function(args, kwargs)
-    return graph_function._filtered_call(flat_args, flat_kwargs)  # pylint: disable=protected-access
+      (graph_function,
+       filtered_flat_args) = self._maybe_define_function(args, kwargs)
+    return graph_function._call_flat(
+        filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
 
   @property
   def python_function(self):
@@ -2952,7 +2952,7 @@
     if self.input_signature:
       args, kwargs = None, None
     with self._lock:
-      graph_function, _, _ = self._maybe_define_function(args, kwargs)
+      graph_function, _ = self._maybe_define_function(args, kwargs)
     return graph_function
 
   def _get_concrete_function_internal(self, *args, **kwargs):
@@ -3002,7 +3002,7 @@
                            (str(args), str(self.input_signature)))
       args, kwargs = None, None
     with self._lock:
-      graph_function, _, _ = self._maybe_define_function(args, kwargs)
+      graph_function, _ = self._maybe_define_function(args, kwargs)
       seen_names = set()
       captured = object_identity.ObjectIdentitySet(
           graph_function.graph.internal_captures)
@@ -3073,7 +3073,11 @@
     # Return the cached `Function` for the instance
     return self._descriptor_cache[instance]
 
-  def _cache_key(self, args, kwargs, include_tensor_ranks_only=False):
+  def _cache_key(self,
+                 args,
+                 kwargs,
+                 cache_key_context,
+                 include_tensor_ranks_only=False):
     """Computes the cache key given inputs and execution context."""
     if self.input_signature is None:
       inputs = (args, kwargs) if kwargs else args
@@ -3085,6 +3089,15 @@
       assert not include_tensor_ranks_only
       hashable_input_signature = self._hashable_input_signature
 
+    (parent_graph, device_functions, colocation_stack, in_cross_replica_context,
+     variable_policy, xla_context_id) = cache_key_context
+
+    return CacheKey(hashable_input_signature, parent_graph, device_functions,
+                    colocation_stack, in_cross_replica_context, variable_policy,
+                    xla_context_id)
+
+  def _cache_key_context(self):
+    """Returns execution context."""
     ctx = context.context()
 
     # Don't need to open an init_scope if the _cache_key call is in eager mode
@@ -3153,9 +3166,8 @@
     else:
       variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
 
-    return CacheKey(hashable_input_signature, parent_graph, device_functions,
-                    colocation_stack, in_cross_replica_context, variable_policy,
-                    xla_context_id)
+    return (parent_graph, device_functions, colocation_stack,
+            in_cross_replica_context, variable_policy, xla_context_id)
 
   def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
     """Create a `ConcreteFunction` from `args` and `kwargs`."""
@@ -3196,27 +3208,31 @@
     return graph_function
 
   def _define_function_with_shape_relaxation(self, args, kwargs, flat_args,
-                                             flat_kwargs):
+                                             filtered_flat_args,
+                                             cache_key_context):
     """Define a function, relaxing arg shapes to avoid unnecessary retracing."""
-    flat_args_all = nest.flatten((args, kwargs), expand_composites=False)
+    flat_no_comp = nest.flatten((args, kwargs), expand_composites=False)
 
     any_composite_args = any(
-        isinstance(x, composite_tensor.CompositeTensor) for x in flat_args_all)
+        isinstance(x, composite_tensor.CompositeTensor) for x in flat_no_comp)
 
     # Build a cache key where TensorShapes include only rank information (and
     # not information about the size of each dimension).
     if not any_composite_args:
       rank_only_cache_key = self._cache_key(
-          args, kwargs, include_tensor_ranks_only=True)
+          args, kwargs, cache_key_context, include_tensor_ranks_only=True)
     else:
       # For the rank-only cache key, replace any composite tensors with
       # shape-relaxed TypeSpecs.
       (cache_key_args, cache_key_kwargs) = nest.map_structure(
           _shape_relaxed_type_for_composite_tensor, (args, kwargs))
       rank_only_cache_key = self._cache_key(
-          cache_key_args, cache_key_kwargs, include_tensor_ranks_only=True)
+          cache_key_args,
+          cache_key_kwargs,
+          cache_key_context,
+          include_tensor_ranks_only=True)
 
-    arg_specs = [_type_spec_for(x) for x in flat_args_all]
+    arg_specs = [_type_spec_for(x) for x in flat_no_comp]
     relaxed_arg_specs = self._function_cache.arg_relaxed_specs.get(
         rank_only_cache_key, None)
     relaxed_arg_function = self._function_cache.arg_relaxed.get(
@@ -3225,7 +3241,7 @@
     if (relaxed_arg_function is not None
         and all(_is_type_subset(x, y) for (x, y) in
                 zip(relaxed_arg_specs, arg_specs))):
-      return relaxed_arg_function, flat_args, flat_kwargs
+      return relaxed_arg_function, filtered_flat_args
 
     if relaxed_arg_specs is None:
       relaxed_arg_specs = arg_specs
@@ -3251,15 +3267,18 @@
           (args, kwargs), relaxed_arg_specs, expand_composites=False)
       (args, kwargs) = nest.pack_sequence_as(
           (relaxed_arg_specs, relaxed_kwarg_specs),
-          flat_args + flat_kwargs,
+          flat_args,
           expand_composites=True)
 
     graph_function = self._create_graph_function(
         args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
     self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function
 
-    return (graph_function, nest.flatten(args, expand_composites=True),
-            nest.flatten(kwargs, expand_composites=True))
+    return (graph_function, [
+        t for t in nest.flatten((args, kwargs), expand_composites=True)
+        if isinstance(t, (ops.Tensor,
+                          resource_variable_ops.BaseResourceVariable))
+    ])
 
   def _maybe_define_function(self, args, kwargs):
     """Gets a function for these inputs, defining it if necessary.
@@ -3275,7 +3294,8 @@
 
     Returns:
       A graph function corresponding to the input signature implied by args and
-      kwargs, as well as flattened inputs that the object should be called with.
+      kwargs, as well as filtered flattened inputs (only Tensors and Variables)
+      that the object should be called with.
 
     Raises:
       ValueError: If inputs are incompatible with the input signature.
@@ -3284,12 +3304,13 @@
         shape relaxation retracing.
     """
     if self.input_signature is None or args is not None or kwargs is not None:
-      args, kwargs, flat_args, flat_kwargs = \
+      args, kwargs, flat_args, filtered_flat_args = \
           self._function_spec.canonicalize_function_inputs(*args, **kwargs)
     else:
-      flat_args, flat_kwargs = [None], [None]
+      flat_args, filtered_flat_args = [None], []
 
-    cache_key = self._cache_key(args, kwargs)
+    cache_key_context = self._cache_key_context()
+    cache_key = self._cache_key(args, kwargs, cache_key_context)
 
     try:
       hash(cache_key)
@@ -3300,43 +3321,44 @@
 
     graph_function = self._function_cache.primary.get(cache_key, None)
     if graph_function is not None:
-      return graph_function, flat_args, flat_kwargs
+      return graph_function, filtered_flat_args
 
-    logging.vlog(1,
-                 "Creating new FuncGraph for Python function %r (key: %r)",
-                 self._python_function, cache_key)
-    logging.vlog(2,
-                 "Python function signature [args: %s] [kwargs: %s]",
-                 args,
-                 kwargs)
+    with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()):
+      with trace.Trace("tf.function-graph_building"):
+        logging.vlog(1,
+                     "Creating new FuncGraph for Python function %r (key: %r)",
+                     self._python_function, cache_key)
+        logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
+                     args, kwargs)
 
-    # pylint: disable=protected-access
-    call_context_key = cache_key._replace(input_signature=None)
-    # pylint: disable=protected-access
+        # pylint: disable=protected-access
+        call_context_key = cache_key._replace(input_signature=None)
+        # pylint: disable=protected-access
 
-    ag_status = (
-        ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED)
-    with ag_ctx.ControlStatusCtx(
-        status=ag_status, options=self._autograph_options):
+        ag_status = (
+            ag_ctx.Status.ENABLED
+            if self._autograph else ag_ctx.Status.DISABLED)
+        with ag_ctx.ControlStatusCtx(
+            status=ag_status, options=self._autograph_options):
 
-      # Build a function with shape relaxation retracing if:
-      # 1. shape relaxation is explicitly enabled
-      # and 2. there's no provided input signature
-      # and 3. there's been a cache miss for this calling context
-      if (self._experimental_relax_shapes
-          and self.input_signature is None
-          and call_context_key in self._function_cache.missed):
-        return self._define_function_with_shape_relaxation(
-            args, kwargs, flat_args, flat_kwargs)
+          # Build a function with shape relaxation retracing if:
+          # 1. shape relaxation is explicitly enabled
+          # and 2. there's no provided input signature
+          # and 3. there's been a cache miss for this calling context
+          if (self._experimental_relax_shapes and
+              self.input_signature is None and
+              call_context_key in self._function_cache.missed):
+            return self._define_function_with_shape_relaxation(
+                args, kwargs, flat_args, filtered_flat_args, cache_key_context)
 
-      self._function_cache.missed.add(call_context_key)
-      graph_function = self._create_graph_function(args, kwargs)
-      self._function_cache.primary[cache_key] = graph_function
+          self._function_cache.missed.add(call_context_key)
+          graph_function = self._create_graph_function(args, kwargs)
+          self._function_cache.primary[cache_key] = graph_function
 
-      if ops.get_default_graph()._distribution_strategy_stack:
-        self._traced_with_distribution_strategy = True
+          if ops.get_default_graph()._distribution_strategy_stack:
+            self._traced_with_distribution_strategy = True
 
-      return graph_function, flat_args, flat_kwargs
+          return graph_function, filtered_flat_args
 
 
 def register(func, *args, **kwargs):
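
The graph-building metric added at the top of this file pairs a monitoring.Counter with monitoring.MonitoredTimer; the same pattern can instrument any code path. A sketch with an illustrative metric name:

from tensorflow.python.eager import monitoring

_slow_path_time_counter = monitoring.Counter(
    "/tensorflow/examples/slow_path_time_usecs",
    "Time spent in the slow path (us).")

def slow_path():
  with monitoring.MonitoredTimer(_slow_path_time_counter.get_cell()):
    pass  # the work being timed goes here
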
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 65b2340..57e5e17 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -56,6 +56,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.layers import convolutional
+from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import clip_ops
@@ -4220,6 +4221,25 @@
     enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700)  # No retrace
     self.assertEqual(trace_count[0], 4)
 
+  def testWithModuleNameScope(self):
+    self.skipTest('b/166158748: function does not handle this case correctly.')
+
+    class Foo(module.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.var = None
+
+      @def_function.function
+      @module.Module.with_name_scope
+      def bar(self, x, y):
+        if self.var is None:
+          return x
+
+    foo = Foo()
+    with self.assertRaisesRegex(TypeError, 'got two values for argument'):
+      foo.bar(2, x=3)  # pylint: disable=redundant-keyword-arg
+
 
 class MultiDeviceTest(test.TestCase, parameterized.TestCase):
 
diff --git a/tensorflow/python/eager/remote_benchmarks_test.py b/tensorflow/python/eager/remote_benchmarks_test.py
index 300ce0c..7343713 100644
--- a/tensorflow/python/eager/remote_benchmarks_test.py
+++ b/tensorflow/python/eager/remote_benchmarks_test.py
@@ -92,7 +92,7 @@
         wall_time=mean_us,
         extras={"examples_per_sec": num_iters / total_time})
 
-  def benchmark_send_mirroring_off(self):
+  def benchmark_send(self):
     remote.connect_to_remote_host(self._cached_server_target1)
 
     x = random_ops.random_uniform((2, 2)).cpu()
@@ -105,34 +105,13 @@
       with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
         return remote_func(m)
 
-    context.context().mirroring_policy = context.MIRRORING_NONE
     self._run(lambda: func(x))
     # NOTE(b/136184459): Force garbage collecting hanging resources before
     # subsequent calls to set_server_def, to ensure the destroy resource ops are
     # executed when their corresponding device and manager are still available.
     gc.collect()
 
-  def benchmark_send_mirroring_on(self):
-    remote.connect_to_remote_host(self._cached_server_target1)
-
-    x = random_ops.random_uniform((2, 2)).cpu()
-
-    @def_function.function
-    def remote_func(m):
-      return math_ops.matmul(m, m)
-
-    def func(m):
-      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
-        return remote_func(m)
-
-    context.context().mirroring_policy = context.MIRRORING_ALL
-    self._run(lambda: func(x))
-    # NOTE(b/136184459): Force garbage collecting hanging resources before
-    # subsequent calls to set_server_def, to ensure the destroy resource ops are
-    # executed when their corresponding device and manager are still available.
-    gc.collect()
-
-  def benchmark_worker_mirroring_off(self):
+  def benchmark_worker_recv(self):
     remote.connect_to_remote_host(
         [self._cached_server_target1, self._cached_server_target2])
 
@@ -147,29 +126,6 @@
       with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
         return remote_func()
 
-    context.context().mirroring_policy = context.MIRRORING_NONE
-    self._run(func)
-    # NOTE(b/136184459): Force garbage collecting hanging resources before
-    # subsequent calls to set_server_def, to ensure the destroy resource ops are
-    # executed when their corresponding device and manager are still available.
-    gc.collect()
-
-  def benchmark_worker_mirroring_on(self):
-    remote.connect_to_remote_host(
-        [self._cached_server_target1, self._cached_server_target2])
-
-    with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
-      v = variables.Variable(1.0)
-
-    @def_function.function
-    def remote_func():
-      return 1.0 + v
-
-    def func():
-      with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
-        return remote_func()
-
-    context.context().mirroring_policy = context.MIRRORING_ALL
     self._run(func)
     # NOTE(b/136184459): Force garbage collecting hanging resources before
     # subsequent calls to set_server_def, to ensure the destroy resource ops are
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index 4290681..0fb78cb 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -468,17 +468,6 @@
       c = a + 1.0
       return c
 
-    context.context().mirroring_policy = context.MIRRORING_NONE
-
-    with ops.device('/job:worker/replica:0/task:0'):
-      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
-
-    if test_util.is_gpu_available():
-      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
-        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
-
-    context.context().mirroring_policy = context.MIRRORING_ALL
-
     with ops.device('/job:worker/replica:0/task:0'):
       self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
@@ -520,17 +509,6 @@
 
       return control_flow_ops.while_loop_v2(lambda _, d: d < 1, body, [i, 0])[0]
 
-    context.context().mirroring_policy = context.MIRRORING_NONE
-
-    with ops.device('/job:worker/replica:0/task:0'):
-      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
-
-    if test_util.is_gpu_available():
-      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
-        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
-
-    context.context().mirroring_policy = context.MIRRORING_ALL
-
     with ops.device('/job:worker/replica:0/task:0'):
       self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
 
diff --git a/tensorflow/python/framework/composite_tensor.py b/tensorflow/python/framework/composite_tensor.py
index b7a4d65..e3db993 100644
--- a/tensorflow/python/framework/composite_tensor.py
+++ b/tensorflow/python/framework/composite_tensor.py
@@ -58,8 +58,8 @@
 
     Args:
       shape: A `tf.TensorShape` object.  The shape invariant for this
-        `CompositeTensor`, or `None` if a default shape invariant should be
-        used (based on the value of this `CompositeTensor`).
+        `CompositeTensor`, or `None` if a default shape invariant should be used
+        (based on the value of this `CompositeTensor`).
 
     Returns:
       A nested structure whose values are `tf.TensorShape` objects, specifying
@@ -68,8 +68,8 @@
     # New TypeSpec subclasses generally do not need to implement this --
     # this method is used for backwards compatibility.  Users of tf.while_loop
     # can specify a type by passing in TypeSpec instead.
-    raise NotImplementedError("%s._shape_invariant_to_type_spec"
-                              % type(self).__name__)
+    raise NotImplementedError("%s._shape_invariant_to_type_spec" %
+                              type(self).__name__)
 
   def _consumers(self):
     """Returns a list of `Operation`s that consume this `CompositeTensor`.
@@ -105,12 +105,13 @@
     returns the same value as `nest.flatten(structure)`.
   """
   if isinstance(structure, CompositeTensor):
-    return replace_composites_with_components(structure._to_components())  # pylint: disable=protected-access
+    return replace_composites_with_components(
+        structure._type_spec._to_components(structure))  # pylint: disable=protected-access
   elif not nest.is_sequence(structure):
     return structure
   else:
-    return nest.map_structure(replace_composites_with_components, structure,
-                              expand_composites=False)
+    return nest.map_structure(
+        replace_composites_with_components, structure, expand_composites=False)
 
 
 # @TODO(edloper): Can we replace convert_to_tensor_or_xyz with just
diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py
index 0962b9a..6593c05 100644
--- a/tensorflow/python/framework/config.py
+++ b/tensorflow/python/framework/config.py
@@ -24,34 +24,73 @@
 from tensorflow.python.util.tf_export import tf_export
 
 
-# No tf_export until TF is built against CUDA11 which is required for TF32.
-def tensor_float_32_execution_allowed():
-  """Get if TensorFloat-32 operations are enabled on supported hardware.
+@tf_export('config.experimental.tensor_float_32_execution_enabled')
+def tensor_float_32_execution_enabled():
+  """Returns whether TensorFloat-32 is enabled.
+
+  By default, TensorFloat-32 is enabled, but this can be changed with
+  `tf.config.experimental.enable_tensor_float_32_execution`.
 
   Returns:
-    True if TensorFloat-32 execution is enabled and False otherwise.
+    True if TensorFloat-32 is enabled (the default) and False otherwise.
   """
   return _pywrap_tf32_execution.is_allowed()
 
 
-# No tf_export until TF is built against CUDA11 which is required for TF32.
-def allow_tensor_float_32_execution(allowed):
-  """Allow use of TensorFloat-32 with float32 ops on supported hardware.
+@tf_export('config.experimental.enable_tensor_float_32_execution')
+def enable_tensor_float_32_execution(enabled):
+  """Enable or disable the use of TensorFloat-32 on supported hardware.
 
-  TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
-  TensorFloat-32 kernels take float32 inputs and produce float32 outputs.
-  Internally, the inputs are cast to a custom representation with 10-bit
-  mantissa (similar to float16) and 8-bit exponent (similar to float32) and are
-  executed using TensorCores with float32 accumulation. For more information,
-  see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/.
+  [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format),
+  or TF32 for short, is a math mode for NVIDIA Ampere GPUs. TensorFloat-32
+  execution causes certain float32 ops, such as matrix multiplications and
+  convolutions, to run much faster on Ampere GPUs but with reduced precision.
+  This reduced precision should not impact convergence of deep learning models
+  in practice.
 
-  TensorFloat-32 execution is disabled by default, but this may change in a
-  future version.
+  TensorFloat-32 is enabled by default in the nightly versions of TensorFlow.
+  We expect it will remain enabled by default in the first stable version in
+  which TensorFloat-32 is available, which is TensorFlow 2.4, as it increases
+  performance and does not reduce model quality in practice. If you want the
+  full float32 precision, you can disable TensorFloat-32 execution with this
+  function. For example:
+
+  ```python
+  x = tf.fill((2, 2), 1.0001)
+  y = tf.fill((2, 2), 1.)
+  # TensorFloat-32 is enabled, so matmul is run with reduced precision
+  print(tf.linalg.matmul(x, y))  # [[2., 2.], [2., 2.]]
+  tf.config.experimental.enable_tensor_float_32_execution(False)
+  # Matmul is run with full precision
+  print(tf.linalg.matmul(x, y))  # [[2.0002, 2.0002], [2.0002, 2.0002]]
+  ```
+
+  We will soon create an RFC proposing that TensorFloat-32 remain enabled by
+  default in stable versions of TensorFlow. We expect the RFC to be accepted,
+  but if it is not, TensorFloat-32 will be disabled by default in TensorFlow
+  2.4.
+
+  To check whether TensorFloat-32 execution is currently enabled, use
+  `tf.config.experimental.tensor_float_32_execution_enabled`.
+
+  Enabling TensorFloat-32 causes float32 inputs of supported ops, such as
+  `tf.linalg.matmul`, to be rounded from 23 bits of precision to 10 bits of
+  precision in most cases. This allows the ops to execute much faster by
+  utilizing the GPU's tensor cores. TensorFloat-32 has the same dynamic range as
+  float32, meaning it is no more likely to underflow or overflow than float32.
+  Ops still use float32 accumulation when TensorFloat-32 is enabled. Enabling
+  TensorFloat-32 only affects Ampere GPUs and subsequent GPUs that support
+  TensorFloat-32.
+
+  Note that TensorFloat-32 is not always used in supported ops, as only inputs
+  of certain shapes are supported. Support for more input shapes and more ops
+  may be added in the future. As a result, the precision of float32 ops may
+  decrease in minor versions of TensorFlow.
 
   Args:
-    allowed: whether to allow TensorFloat-32 execution
+    enabled: Bool indicating whether to enable TensorFloat-32 execution.
   """
-  _pywrap_tf32_execution.allow(allowed)
+  _pywrap_tf32_execution.allow(enabled)
 
 
 @tf_export('config.threading.get_intra_op_parallelism_threads')
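
Taken together, the two new endpoints above can be exercised as follows. This is a minimal sketch assuming a TF build that exports these `tf.config.experimental` symbols; the precision difference only materializes on Ampere-or-newer GPUs:

```python
import tensorflow as tf

x = tf.fill((8, 8), 1 + 2**-20)  # representable exactly in float32
y = tf.ones((8, 8))

# TensorFloat-32 is on by default in builds that include this change.
assert tf.config.experimental.tensor_float_32_execution_enabled()
print(tf.linalg.matmul(x, y)[0, 0])  # ~8.0 on TF32-capable GPUs

tf.config.experimental.enable_tensor_float_32_execution(False)
print(tf.linalg.matmul(x, y)[0, 0])  # exactly 8 * (1 + 2**-20)

tf.config.experimental.enable_tensor_float_32_execution(True)  # restore
```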
diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py
index ee7e111..d2314f1 100644
--- a/tensorflow/python/framework/config_test.py
+++ b/tensorflow/python/framework/config_test.py
@@ -750,18 +750,18 @@
 class TensorFloat32Test(test.TestCase):
 
   def setUp(self):
+    super(TensorFloat32Test, self).setUp()
     if not test_util.is_gpu_available(
         cuda_only=True, min_cuda_compute_capability=(8, 0)):
       self.skipTest('TensorFloat-32 requires an NVIDIA GPU with compute '
                     'capability of at least 8.0')
 
   def tearDown(self):
-    config.allow_tensor_float_32_execution(False)
+    super(TensorFloat32Test, self).tearDown()
+    config.enable_tensor_float_32_execution(True)
 
   def test_tf32_enabled(self):
-    self.assertFalse(config.tensor_float_32_execution_allowed())
-    config.allow_tensor_float_32_execution(True)
-    self.assertTrue(config.tensor_float_32_execution_allowed())
+    self.assertTrue(config.tensor_float_32_execution_enabled())
 
     x = array_ops.fill((8, 8), 1 + 2**-20)
     y = array_ops.ones((8, 8))
@@ -771,19 +771,16 @@
     self.assertAllEqual(out, expected)
 
   def test_tf32_disabled(self):
+    self.assertTrue(config.tensor_float_32_execution_enabled())
+    config.enable_tensor_float_32_execution(False)
+    self.assertFalse(config.tensor_float_32_execution_enabled())
+
     x = array_ops.fill((8, 8), 1 + 2**-20)
     y = array_ops.ones((8, 8))
     out = math_ops.matmul(x, y)
     expected = array_ops.fill((8, 8), 8 * (1 + 2**-20))
     self.assertAllEqual(out, expected)
 
-    # Test disabling tf32 after enabling it works correctly
-    config.allow_tensor_float_32_execution(True)
-    config.allow_tensor_float_32_execution(False)
-    self.assertFalse(config.tensor_float_32_execution_allowed())
-    out = math_ops.matmul(x, y)
-    self.assertAllEqual(out, expected)
-
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 596b932..298d41a 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1592,6 +1592,7 @@
       self.assertAllClose(mv0, mv2, rtol=1e-4)
       self.assertAllClose(mv0, mv3, rtol=1e-4)
 
+  @test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function")
   def testUnrollLSTMGrad(self):
     # Run one step of the unrolled lstm graph.
     def RunForwardBackward(mode, cfg=None):
diff --git a/tensorflow/python/framework/indexed_slices.py b/tensorflow/python/framework/indexed_slices.py
index 45f6e25..b1e1e20 100644
--- a/tensorflow/python/framework/indexed_slices.py
+++ b/tensorflow/python/framework/indexed_slices.py
@@ -429,9 +429,12 @@
             "elements. This may consume a large amount of memory." %
             num_elements)
     else:
-      warnings.warn(
-          "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
-          "This may consume a large amount of memory.")
+      if value.dense_shape.op.type != "VariableShape":
+        # VariableShape may hide a static shape behind a resource handle,
+        # which would produce a warning that isn't useful to users.
+        warnings.warn(
+            "Converting sparse IndexedSlices(%s) to a dense Tensor of unknown "
+            "shape. This may consume a large amount of memory." % value)
   return math_ops.unsorted_segment_sum(
       value.values, value.indices, value.dense_shape[0], name=name)
 
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index 53d0927..016af65 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -63,6 +63,27 @@
            ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
 
 
+def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
+  if attr_def.has_minimum and length < attr_def.minimum:
+    raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
+                     "less than minimum %d." %
+                     (param_name, op_type_name, length, attr_def.minimum))
+
+
+def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
+  if value not in attr_def.allowed_values.list.s:
+    raise ValueError(
+        "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
+        (arg_name, op_type_name, compat.as_text(value), '", "'.join(
+            map(compat.as_text, attr_def.allowed_values.list.s))))
+
+
+def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
+  if value < attr_def.minimum:
+    raise ValueError("Attr '%s' of '%s' Op passed %d less than minimum %d." %
+                     (arg_name, op_type_name, value, attr_def.minimum))
+
+
 def _IsListParameter(arg):
   if arg.number_attr:
     return True
@@ -172,15 +193,13 @@
   return v
 
 
-def _MakeType(v, attr_def):
+def _MakeType(v, arg_name):
   try:
     v = dtypes.as_dtype(v).base_dtype
   except TypeError:
     raise TypeError("Expected DataType for argument '%s' not %s." %
-                    (attr_def.name, repr(v)))
-  i = v.as_datatype_enum
-  _SatisfiesTypeConstraint(i, attr_def, param_name=attr_def.name)
-  return i
+                    (arg_name, repr(v)))
+  return v.as_datatype_enum
 
 
 def _MakeShape(v, arg_name):
@@ -670,78 +689,32 @@
     for attr_def in op_def.attr:
       key = attr_def.name
       value = attrs[key]
-      attr_value = attr_value_pb2.AttrValue()
+
       if attr_def.HasField("default_value") and value is None:
+        attr_value = attr_value_pb2.AttrValue()
         attr_value.CopyFrom(attr_def.default_value)
         attr_protos[key] = attr_value
         continue
+
+      attr_value = value_to_attr_value(value, attr_def.type, key)
       if attr_def.type.startswith("list("):
-        if not _IsListValue(value):
-          raise TypeError("Expected list for attr " + key)
-        if attr_def.has_minimum:
-          if len(value) < attr_def.minimum:
-            raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
-                             "less than minimum %d." %
-                             (key, op_type_name, len(value),
-                              attr_def.minimum))
-        attr_value.list.SetInParent()
-      if attr_def.type == "string":
-        attr_value.s = _MakeStr(value, key)
-        if attr_def.HasField("allowed_values"):
-          if attr_value.s not in attr_def.allowed_values.list.s:
-            raise ValueError(
-                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
-                (key, op_type_name, compat.as_text(attr_value.s),
-                 '", "'.join(map(compat.as_text,
-                                 attr_def.allowed_values.list.s))))
-      elif attr_def.type == "list(string)":
-        attr_value.list.s.extend([_MakeStr(x, key) for x in value])
-        if attr_def.HasField("allowed_values"):
-          for x in attr_value.list.s:
-            if x not in attr_def.allowed_values.list.s:
-              raise ValueError(
-                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
-                  (key, op_type_name, compat.as_text(x),
-                   '", "'.join(map(compat.as_text,
-                                   attr_def.allowed_values.list.s))))
-      elif attr_def.type == "int":
-        attr_value.i = _MakeInt(value, key)
-        if attr_def.has_minimum:
-          if attr_value.i < attr_def.minimum:
-            raise ValueError(
-                "Attr '%s' of '%s' Op passed %d less than minimum %d." %
-                (key, op_type_name, attr_value.i, attr_def.minimum))
-      elif attr_def.type == "list(int)":
-        attr_value.list.i.extend([_MakeInt(x, key) for x in value])
-      elif attr_def.type == "float":
-        attr_value.f = _MakeFloat(value, key)
-      elif attr_def.type == "list(float)":
-        attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
-      elif attr_def.type == "bool":
-        attr_value.b = _MakeBool(value, key)
-      elif attr_def.type == "list(bool)":
-        attr_value.list.b.extend([_MakeBool(x, key) for x in value])
-      elif attr_def.type == "type":
-        attr_value.type = _MakeType(value, attr_def)
-      elif attr_def.type == "list(type)":
-        attr_value.list.type.extend(
-            [_MakeType(x, attr_def) for x in value])
-      elif attr_def.type == "shape":
-        attr_value.shape.CopyFrom(_MakeShape(value, key))
-      elif attr_def.type == "list(shape)":
-        attr_value.list.shape.extend(
-            [_MakeShape(x, key) for x in value])
-      elif attr_def.type == "tensor":
-        attr_value.tensor.CopyFrom(_MakeTensor(value, key))
-      elif attr_def.type == "list(tensor)":
-        attr_value.list.tensor.extend(
-            [_MakeTensor(x, key) for x in value])
-      elif attr_def.type == "func":
-        attr_value.func.CopyFrom(_MakeFunc(value, key))
-      elif attr_def.type == "list(func)":
-        attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
-      else:
-        raise TypeError("Unrecognized Attr type " + attr_def.type)
+        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
+      if attr_def.HasField("allowed_values"):
+        if attr_def.type == "string":
+          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
+                                             op_type_name)
+        elif attr_def.type == "list(string)":
+          # Use a distinct loop variable so `value` above is not clobbered.
+          for x in attr_value.list.s:
+            _SatisfiesAllowedStringsConstraint(x, attr_def, key, op_type_name)
+      if attr_def.has_minimum and attr_def.type == "int":
+        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
+                                       op_type_name)
+      if attr_def.type == "type":
+        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
+      if attr_def.type == "list(type)":
+        for value in attr_value.list.type:
+          _SatisfiesTypeConstraint(value, attr_def, key)
 
       attr_protos[key] = attr_value
     del attrs  # attrs is no longer authoritative, use attr_protos instead
@@ -792,6 +765,61 @@
     return output_structure, op_def.is_stateful, op, outputs
 
 
+def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
+  """Encodes a Python value as an `AttrValue` proto message.
+
+  Args:
+    value: The value to convert.
+    attr_type: The value type (string) -- see the AttrValue proto definition for
+      valid strings.
+    arg_name: Argument name (for error messages).
+
+  Returns:
+    An AttrValue proto message that encodes `value`.
+  """
+  attr_value = attr_value_pb2.AttrValue()
+
+  if attr_type.startswith("list("):
+    if not _IsListValue(value):
+      raise TypeError("Expected list for attr " + arg_name)
+
+  if attr_type == "string":
+    attr_value.s = _MakeStr(value, arg_name)
+  elif attr_type == "list(string)":
+    attr_value.list.s.extend([_MakeStr(x, arg_name) for x in value])
+  elif attr_type == "int":
+    attr_value.i = _MakeInt(value, arg_name)
+  elif attr_type == "list(int)":
+    attr_value.list.i.extend([_MakeInt(x, arg_name) for x in value])
+  elif attr_type == "float":
+    attr_value.f = _MakeFloat(value, arg_name)
+  elif attr_type == "list(float)":
+    attr_value.list.f.extend([_MakeFloat(x, arg_name) for x in value])
+  elif attr_type == "bool":
+    attr_value.b = _MakeBool(value, arg_name)
+  elif attr_type == "list(bool)":
+    attr_value.list.b.extend([_MakeBool(x, arg_name) for x in value])
+  elif attr_type == "type":
+    attr_value.type = _MakeType(value, arg_name)
+  elif attr_type == "list(type)":
+    attr_value.list.type.extend([_MakeType(x, arg_name) for x in value])
+  elif attr_type == "shape":
+    attr_value.shape.CopyFrom(_MakeShape(value, arg_name))
+  elif attr_type == "list(shape)":
+    attr_value.list.shape.extend([_MakeShape(x, arg_name) for x in value])
+  elif attr_type == "tensor":
+    attr_value.tensor.CopyFrom(_MakeTensor(value, arg_name))
+  elif attr_type == "list(tensor)":
+    attr_value.list.tensor.extend([_MakeTensor(x, arg_name) for x in value])
+  elif attr_type == "func":
+    attr_value.func.CopyFrom(_MakeFunc(value, arg_name))
+  elif attr_type == "list(func)":
+    attr_value.list.func.extend([_MakeFunc(x, arg_name) for x in value])
+  else:
+    raise TypeError("Unrecognized Attr type " + attr_type)
+  return attr_value
+
+
 # The following symbols are used by op_def_util.cc.
 _pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType)
 _pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f07bca1..7e51d3a 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1506,6 +1506,13 @@
 
   if preferred_dtype is not None:
     preferred_dtype = dtypes.as_dtype(preferred_dtype)
+
+  # See below for the reason why it's `type(value)` and not just `value`.
+  # https://docs.python.org/3.8/reference/datamodel.html#special-lookup
+  overload = getattr(type(value), "__tf_tensor__", None)
+  if overload is not None:
+    return overload(value, dtype, name)
+
   for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
     # If dtype is None but preferred_dtype is not None, we try to
     # cast to preferred_dtype first.
@@ -2333,6 +2340,10 @@
   def __repr__(self):
     return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
 
+  def __tf_tensor__(self, dtype=None, name=None):
+    """Raises a helpful error."""
+    raise TypeError("can't convert Operation '{}' to Tensor".format(self.name))
+
   @property
   def outputs(self):
     """The list of `Tensor` objects representing the outputs of this op."""
@@ -6833,13 +6844,6 @@
     return None
 
 
-def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
-  """Produce a nice error if someone converts an Operation to a Tensor."""
-  raise TypeError(("Can't convert Operation '%s' to Tensor "
-                   "(target dtype=%r, name=%r, as_ref=%r)") %
-                  (op.name, dtype, name, as_ref))
-
-
 def _op_to_colocate_with(v, graph):
   """Operation object corresponding to v to use for colocation constraints."""
   if v is None:
@@ -6873,10 +6877,6 @@
   return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"
 
 
-tensor_conversion_registry.register_tensor_conversion_function(
-    Operation, _operation_conversion_error)
-
-
 # These symbols were originally defined in this module; import them for
 # backwards compatibility until all references have been updated to access
 # them from the indexed_slices.py module.
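
The `__tf_tensor__` hook above gives any class control over its own tensor conversion, and the `Operation` override turns the old registry-based conversion error into a protocol method. A small sketch of the protocol from user code (`Celsius` is a hypothetical example; assumes a TF build that includes this change):

```python
import tensorflow as tf


class Celsius:
  """Toy wrapper that knows how to convert itself to a Tensor."""

  def __init__(self, degrees):
    self.degrees = degrees

  def __tf_tensor__(self, dtype=None, name=None):
    # Looked up on type(value), per Python's special-method rules.
    return tf.constant(self.degrees, dtype=dtype, name=name)


t = tf.convert_to_tensor(Celsius(21.5))
print(t.dtype)  # <dtype: 'float32'>
```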
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 4129b55..58e3f65 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -858,12 +858,25 @@
     with self.assertRaises(ValueError):
       ops.convert_to_tensor(tensor, dtype=dtypes.int32)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testConvertToTensorProtocol(self):
+    class TensorCompatible:
+
+      def __tf_tensor__(self, dtype=None, name=None):
+        return constant_op.constant((1, 2, 3), dtype=dtype, name=name)
+
+    tc = TensorCompatible()
+
+    tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32)
+    self.assertEqual(tensor.dtype, dtypes.int32)
+    self.assertAllEqual((1, 2, 3), self.evaluate(tensor))
+
   @test_util.run_deprecated_v1
   def testNoConvert(self):
     # Operation cannot be converted to Tensor.
     op = control_flow_ops.no_op()
     with self.assertRaisesRegex(TypeError,
-                                r"Can't convert Operation '.*' to Tensor"):
+                                "can't convert Operation '.+' to Tensor"):
       ops.convert_to_tensor(op)
 
   def testStr(self):
diff --git a/tensorflow/python/framework/py_context_manager.cc b/tensorflow/python/framework/py_context_manager.cc
new file mode 100644
index 0000000..b895701
--- /dev/null
+++ b/tensorflow/python/framework/py_context_manager.cc
@@ -0,0 +1,74 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/framework/py_context_manager.h"
+
+#include <map>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+bool PyContextManager::Enter(PyObject* py_context_manager) {
+  if (context_manager_) {
+    PyErr_SetString(PyExc_ValueError,
+                    "PyContextManager::Enter must be called at most once.");
+    return false;  // Calling Enter twice is not supported.
+  }
+  if (!py_context_manager) return false;
+  context_manager_.reset(py_context_manager);
+  static char _enter[] = "__enter__";
+  var_.reset(PyObject_CallMethod(context_manager_.get(), _enter, nullptr));
+  return var_ != nullptr;
+}
+
+PyContextManager::~PyContextManager() {
+  if (var_) {
+    static char _exit[] = "__exit__";
+    static char _ooo[] = "OOO";
+    if (PyErr_Occurred()) {
+      PyObject *type, *value, *traceback;
+      PyErr_Fetch(&type, &value, &traceback);
+      value = value ? value : Py_None;
+      traceback = traceback ? traceback : Py_None;
+      Safe_PyObjectPtr result(PyObject_CallMethod(
+          context_manager_.get(), _exit, _ooo, type, value, traceback));
+      if (result) {
+        if (PyObject_IsTrue(result.get())) {
+          PyErr_SetString(
+              PyExc_ValueError,
+              "tensorflow::PyContextManager::Enter does not support "
+              "context managers that suppress exceptions.");
+        } else {
+          PyErr_Restore(type, value, traceback);
+        }
+      }
+    } else {
+      PyObject* result = PyObject_CallMethod(context_manager_.get(), _exit,
+                                             _ooo, Py_None, Py_None, Py_None);
+      if (result) {
+        Py_DECREF(result);
+      } else {
+        LOG(ERROR)
+            << "A context manager wrapped by tensorflow::PyContextManager "
+               "raised a new exception from its __new__ method.  This behavior "
+               "is not supported by PyContextManager, and the exception is "
+               "being suppressed.";
+        PyErr_Clear();
+      }
+    }
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/framework/py_context_manager.h b/tensorflow/python/framework/py_context_manager.h
new file mode 100644
index 0000000..6c15fcc
--- /dev/null
+++ b/tensorflow/python/framework/py_context_manager.h
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_
+#define TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_
+
+#include <Python.h>
+
+#include <string>
+
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
+
+namespace tensorflow {
+
+// Class that wraps a Python context manager, and calls the `__enter__` and
+// `__exit__` methods at appropriate times:
+//
+// * When `PyContextManager::Enter(cm)` is called, the context manager `cm`
+//   is stored, and `cm.__enter__` is called.  The result can be retrieved
+//   with `PyContextManager::var()`.
+// * When the `PyContextManager` is destroyed, then `cm.__exit__` is called
+//   (with information about any active exception).
+// * `PyContextManager::Enter(cm)` may be called at most once. If
+//   `PyContextManager::Enter()` is never called, then the destructor is a
+//   no-op (i.e., `__exit__` is not called).
+//
+// PyContextManager places two restrictions on the wrapped context managers:
+//
+// 1. The context manager may not suppress exceptions -- i.e., `__exit__`
+//    may not return a True value.  If it does, then a new exception will be
+//    set, indicating that this is unsupported.
+// 2. The context manager may not raise an exception from `__exit__` if no
+//    exception is active when it is called.  If it does, then an error
+//    message will be logged, indicating that this is unsupported, and the
+//    exception will be suppressed.
+//
+// These restrictions are both intended to ensure that the state of
+// PyErr_Occurred is unchanged by PyContextManager's destructor.  This is
+// important, because changing the state of PyErr_Occurred in the destructor
+// would mean that we are returning a nullptr with no exception set, or
+// returning a non-null value with an exception set (both of which are invalid).
+class PyContextManager {
+ public:
+  // Calls `py_context_manager.__enter__()`, and stores the result in `var`.
+  // Returns true if `__enter__` succeeds, or false if `__enter__` raises an
+  // exception.  (Also returns false if `py_context_manager` is nullptr.)
+  //
+  // Steals a reference to `py_context_manager`.  (This reference is deleted
+  // when the destructor is called.)
+  bool Enter(PyObject* py_context_manager);
+
+  // Calls `py_context_manager.__exit__()`.
+  ~PyContextManager();
+
+  // Returns the variable returned by `context_manager.__enter__()`.
+  // (This is the `var` bound by `with context_manager as var`.)
+  // Returns a borrowed reference.
+  PyObject* var() { return var_.get(); }
+
+ protected:
+  Safe_PyObjectPtr context_manager_;
+  Safe_PyObjectPtr var_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_PYTHON_FRAMEWORK_PY_CONTEXT_MANAGER_H_
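
To make restriction 1 concrete, here is how a suppressing context manager surfaces through the test-only `_py_context_manager` pybind module added below. This is a hedged sketch; `Suppressor` and `body` are hypothetical names modeled on the test file in this change:

```python
from tensorflow.python import _py_context_manager


class Suppressor:
  """A context manager that (unsupportedly) swallows exceptions."""

  def __enter__(self):
    return "var"

  def __exit__(self, ex_type, ex_value, ex_tb):
    return True  # Suppressing the exception is not supported.


def body(var):
  raise ValueError("Foo")


try:
  _py_context_manager.test_py_context_manager(Suppressor(), body)
except ValueError as e:
  print(e)  # ...does not support context managers that suppress exceptions.
```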
diff --git a/tensorflow/python/framework/py_context_manager_pybind.cc b/tensorflow/python/framework/py_context_manager_pybind.cc
new file mode 100644
index 0000000..3456514
--- /dev/null
+++ b/tensorflow/python/framework/py_context_manager_pybind.cc
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "tensorflow/python/framework/py_context_manager.h"
+
+namespace py = pybind11;
+
+namespace {
+
+// Test harness for PyContextManager.  Creates a PyContextManager `cm` that
+// wraps `context_manager`, calls `cm.Enter()`, and then calls `body_func`
+// with `cm.var()`.  Returns the result of the function.
+py::handle TestPyContextManager(py::handle context_manager,
+                                py::handle body_func) {
+  tensorflow::Safe_PyObjectPtr result;
+  {
+    tensorflow::PyContextManager cm;
+    Py_INCREF(context_manager.ptr());  // cm.Enter steals a reference.
+    if (!cm.Enter(context_manager.ptr())) {
+      throw py::error_already_set();
+    }
+    result.reset(
+        PyObject_CallFunctionObjArgs(body_func.ptr(), cm.var(), nullptr));
+  }
+  // cm gets destroyed here.
+
+  if (result) {
+    return result.release();
+  } else {
+    throw py::error_already_set();
+  }
+}
+
+}  // namespace
+
+PYBIND11_MODULE(_py_context_manager, m) {
+  m.def("test_py_context_manager", TestPyContextManager);
+}
diff --git a/tensorflow/python/framework/py_context_manager_test.py b/tensorflow/python/framework/py_context_manager_test.py
new file mode 100644
index 0000000..60c72a8
--- /dev/null
+++ b/tensorflow/python/framework/py_context_manager_test.py
@@ -0,0 +1,118 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework._py_context_manager."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import _py_context_manager
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class TestContextManager(object):
+
+  def __init__(self, behavior="basic"):
+    self.log = []
+    self.behavior = behavior
+
+  def __enter__(self):
+    self.log.append("__enter__()")
+    if self.behavior == "raise_from_enter":
+      raise ValueError("exception in __enter__")
+    return "var"
+
+  def __exit__(self, ex_type, ex_value, ex_tb):
+    self.log.append("__exit__(%s, %s, %s)" % (ex_type, ex_value, ex_tb))
+    if self.behavior == "raise_from_exit":
+      raise ValueError("exception in __exit__")
+    if self.behavior == "suppress_exception":
+      return True
+
+
+# Expected log when the body doesn't raise an exception.
+NO_EXCEPTION_LOG = """\
+__enter__()
+body('var')
+__exit__(None, None, None)"""
+
+# Expected log when the body does raise an exception.  (Regular expression.)
+EXCEPTION_LOG = """\
+__enter__\\(\\)
+body\\('var'\\)
+__exit__\\(<class 'ValueError'>, Foo, <traceback object.*>\\)"""
+
+
+class PyContextManagerTest(test_util.TensorFlowTestCase):
+
+  def testBasic(self):
+    cm = TestContextManager()
+
+    def body(var):
+      cm.log.append("body(%r)" % var)
+
+    _py_context_manager.test_py_context_manager(cm, body)
+    self.assertEqual("\n".join(cm.log), NO_EXCEPTION_LOG)
+
+  def testBodyRaisesException(self):
+    cm = TestContextManager()
+
+    def body(var):
+      cm.log.append("body(%r)" % var)
+      raise ValueError("Foo")
+
+    with self.assertRaisesRegexp(ValueError, "Foo"):
+      _py_context_manager.test_py_context_manager(cm, body)
+    self.assertRegex("\n".join(cm.log), EXCEPTION_LOG)
+
+  def testEnterRaisesException(self):
+    cm = TestContextManager("raise_from_enter")
+
+    def body(var):
+      cm.log.append("body(%r)" % var)
+
+    with self.assertRaisesRegexp(ValueError, "exception in __enter__"):
+      _py_context_manager.test_py_context_manager(cm, body)
+    self.assertEqual("\n".join(cm.log), "__enter__()")
+
+  # Test behavior in unsupported case where __exit__ raises an exception.
+  def testExitRaisesException(self):
+    cm = TestContextManager("raise_from_exit")
+
+    def body(var):
+      cm.log.append("body(%r)" % var)
+
+    # Note: this does *not* raise an exception (but does log an error):
+    _py_context_manager.test_py_context_manager(cm, body)
+    self.assertEqual("\n".join(cm.log), NO_EXCEPTION_LOG)
+
+  # Test behavior in unsupported case where __exit__ suppresses exception.
+  def testExitSuppressesException(self):
+    cm = TestContextManager("suppress_exception")
+
+    def body(var):
+      cm.log.append("body(%r)" % var)
+      raise ValueError("Foo")
+
+    with self.assertRaisesRegex(
+        ValueError, "tensorflow::PyContextManager::Enter does not support "
+        "context managers that suppress exception"):
+      _py_context_manager.test_py_context_manager(cm, body)
+    self.assertRegex("\n".join(cm.log), EXCEPTION_LOG)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4d7b774..47639c9 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -54,6 +54,7 @@
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import tape
+from tensorflow.python.framework import config
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -70,6 +71,7 @@
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.ops import variables
@@ -1908,6 +1910,75 @@
   return xla_allow_fallback_impl
 
 
+# The description is just for documentation purposes.
+def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
+  """Execute test with TensorFloat-32 disabled.
+
+  While almost every real-world deep learning model runs fine with
+  TensorFloat-32, many tests use assertAllClose or similar methods.
+  TensorFloat-32 matmuls will typically cause such methods to fail with the
+  default tolerances.
+
+  Args:
+    description: A description used for documentation purposes, describing why
+      the test requires TensorFloat-32 to be disabled.
+
+  Returns:
+    Decorator which runs a test with TensorFloat-32 disabled.
+  """
+
+  def decorator(f):
+
+    @functools.wraps(f)
+    def decorated(self, *args, **kwargs):
+      tf32_enabled = config.tensor_float_32_execution_enabled()
+      try:
+        config.enable_tensor_float_32_execution(False)
+        f(self, *args, **kwargs)
+      finally:
+        config.enable_tensor_float_32_execution(tf32_enabled)
+
+    return decorated
+
+  return decorator
+
+
+# The description is just for documentation purposes.
+def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
+  """Execute all tests in a class with TensorFloat-32 disabled."""
+  return for_all_test_methods(run_without_tensor_float_32, description)
+
+
+def matmul_without_tf32(a, b, *args, **kwargs):
+  """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled.
+
+  This effectively runs matmul without TensorFloat-32. It should only be used
+  in tests when verifying that some other op or function works correctly, e.g.
+  to test `tf.linalg.sqrtm` by matrix multiplying the output of the op by
+  itself. In such cases, the matmul itself is not being tested, so it's OK to
+  run it with higher precision.
+
+  If a matmul itself is being tested, or some other op which uses matmul, use
+  `run_without_tensor_float_32` instead.
+
+  Args:
+    a: First input to tf.linalg.matmul
+    b: Second input to tf.linalg.matmul
+    *args: Other positional arguments to tf.linalg.matmul
+    **kwargs: Other keyword arguments to tf.linalg.matmul
+
+  Returns:
+    A tensor with the same type as `a`.
+  """
+  if config.tensor_float_32_execution_enabled() and a.dtype == "float32":
+    a = math_ops.cast(a, "float64")
+    b = math_ops.cast(b, "float64")
+    ret = math_ops.matmul(a, b, *args, **kwargs)
+    return math_ops.cast(ret, "float32")  # `a` was cast to float64 above.
+  else:
+    return math_ops.matmul(a, b, *args, **kwargs)
+
+
 class EagerSessionWarner(object):
 
   def __getattr__(self, attr):
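
Typical usage of the new helpers, as a hedged sketch; `MatmulPrecisionTest` is a hypothetical test case modeled on the config test earlier in this change:

```python
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class MatmulPrecisionTest(test.TestCase):

  @test_util.run_without_tensor_float_32("Asserts exact matmul results")
  def test_full_precision_matmul(self):
    x = array_ops.fill((8, 8), 1 + 2**-20)
    y = array_ops.ones((8, 8))
    # With TF32 disabled, every entry is exactly 8 * (1 + 2**-20).
    self.assertAllEqual(math_ops.matmul(x, y),
                        array_ops.fill((8, 8), 8 * (1 + 2**-20)))


if __name__ == "__main__":
  test.main()
```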
diff --git a/tensorflow/python/grappler/cluster_wrapper.cc b/tensorflow/python/grappler/cluster_wrapper.cc
index aa762cb..dee8e59 100644
--- a/tensorflow/python/grappler/cluster_wrapper.cc
+++ b/tensorflow/python/grappler/cluster_wrapper.cc
@@ -99,7 +99,7 @@
           std::vector<tensorflow::NamedDevice> named_devices;
           for (const auto& s : serialized_named_devices) {
             tensorflow::NamedDevice named_device;
-            if (!named_device.ParseFromString(s)) {
+            if (!named_device.ParseFromString(std::string(s))) {
               throw std::invalid_argument(
                   "The NamedDevice could not be parsed as a valid protocol "
                   "buffer");
@@ -241,7 +241,7 @@
 
   m.def("TF_EstimatePerformance", [](const py::bytes& serialized_device) {
     tensorflow::NamedDevice device;
-    if (!device.ParseFromString(serialized_device)) {
+    if (!device.ParseFromString(std::string(serialized_device))) {
       throw std::invalid_argument(
           "The NamedDevice could not be parsed as a valid protocol buffer");
     }
diff --git a/tensorflow/python/grappler/cost_analyzer_wrapper.cc b/tensorflow/python/grappler/cost_analyzer_wrapper.cc
index ce557b0..4e960bb 100644
--- a/tensorflow/python/grappler/cost_analyzer_wrapper.cc
+++ b/tensorflow/python/grappler/cost_analyzer_wrapper.cc
@@ -32,7 +32,7 @@
         [](const py::bytes& serialized_metagraph, bool per_node_report,
            bool verbose, tensorflow::grappler::Cluster* cluster) -> py::bytes {
           tensorflow::MetaGraphDef metagraph;
-          if (!metagraph.ParseFromString(serialized_metagraph)) {
+          if (!metagraph.ParseFromString(std::string(serialized_metagraph))) {
             return "The MetaGraphDef could not be parsed as a valid protocol "
                    "buffer";
           }
diff --git a/tensorflow/python/grappler/item_wrapper.cc b/tensorflow/python/grappler/item_wrapper.cc
index e55b468..3b29392 100644
--- a/tensorflow/python/grappler/item_wrapper.cc
+++ b/tensorflow/python/grappler/item_wrapper.cc
@@ -129,7 +129,7 @@
         [](const py::bytes& serialized_metagraph, bool ignore_colocation,
            bool ignore_user_placement) -> tensorflow::grappler::GrapplerItem* {
           tensorflow::MetaGraphDef metagraph;
-          if (!metagraph.ParseFromString(serialized_metagraph)) {
+          if (!metagraph.ParseFromString(std::string(serialized_metagraph))) {
             throw std::invalid_argument(
                 "The MetaGraphDef could not be parsed as a valid protocol "
                 "buffer");
diff --git a/tensorflow/python/grappler/model_analyzer_wrapper.cc b/tensorflow/python/grappler/model_analyzer_wrapper.cc
index 47d1ec8..68740ca 100644
--- a/tensorflow/python/grappler/model_analyzer_wrapper.cc
+++ b/tensorflow/python/grappler/model_analyzer_wrapper.cc
@@ -29,7 +29,7 @@
         [](const py::bytes& serialized_metagraph, bool assume_valid_feeds,
            bool debug) -> py::bytes {
           tensorflow::MetaGraphDef metagraph;
-          if (!metagraph.ParseFromString(serialized_metagraph)) {
+          if (!metagraph.ParseFromString(std::string(serialized_metagraph))) {
             return "The MetaGraphDef could not be parsed as a valid protocol "
                    "buffer";
           }
diff --git a/tensorflow/python/grappler/tf_optimizer_wrapper.cc b/tensorflow/python/grappler/tf_optimizer_wrapper.cc
index 14336a0..32446a6 100644
--- a/tensorflow/python/grappler/tf_optimizer_wrapper.cc
+++ b/tensorflow/python/grappler/tf_optimizer_wrapper.cc
@@ -66,12 +66,13 @@
          const std::string& graph_id,
          bool strip_default_attributes) -> py::bytes {
         tensorflow::ConfigProto config_proto;
-        if (!config_proto.ParseFromString(serialized_config_proto)) {
+        if (!config_proto.ParseFromString(
+                std::string(serialized_config_proto))) {
           throw std::invalid_argument(
               "The ConfigProto could not be parsed as a valid protocol buffer");
         }
         tensorflow::MetaGraphDef metagraph;
-        if (!metagraph.ParseFromString(serialized_metagraph)) {
+        if (!metagraph.ParseFromString(std::string(serialized_metagraph))) {
           throw std::invalid_argument(
               "The MetaGraphDef could not be parsed as a valid protocol "
               "buffer");
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index d8eff0f..95a192a 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -307,6 +307,7 @@
     deps = [
         ":backend",
         ":models",
+        "//tensorflow/python:config",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index fe0bf59..119851f 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -499,8 +499,10 @@
 def deserialize(name, custom_objects=None):
   """Returns activation function given a string identifier.
 
-  Arguments:
-      x : String identifier.
+  Args:
+    name: The name of the activation function.
+    custom_objects: Optional `{function_name: function_obj}`
+      dictionary listing user-provided activation functions.
 
   Returns:
       Corresponding activation function.
@@ -516,11 +518,6 @@
   ...
   ValueError: Unknown activation function:abcd
 
-  Args:
-    name: The name of the activation function.
-    custom_objects: Optional `{function_name: function_obj}`
-      dictionary listing user-provided activation functions.
-
   Raises:
       ValueError: `Unknown activation function` if the input string does not
       denote any defined Tensorflow activation function.
diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py
index 1302598..03f5754 100644
--- a/tensorflow/python/keras/applications/densenet.py
+++ b/tensorflow/python/keras/applications/densenet.py
@@ -192,7 +192,7 @@
     ValueError: if `classifier_activation` is not `softmax` or `None` when
       using a pretrained top layer.
   """
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/efficientnet.py b/tensorflow/python/keras/applications/efficientnet.py
index b791bbc..1e75d32 100644
--- a/tensorflow/python/keras/applications/efficientnet.py
+++ b/tensorflow/python/keras/applications/efficientnet.py
@@ -269,7 +269,7 @@
   if blocks_args == 'default':
     blocks_args = DEFAULT_BLOCKS_ARGS
 
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py
index 3bf2969..2ac2138 100644
--- a/tensorflow/python/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/applications/inception_resnet_v2.py
@@ -112,7 +112,7 @@
     layers = VersionAwareLayers()
   if kwargs:
     raise ValueError('Unknown argument(s): %s' % (kwargs,))
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py
index 7237cf5..59a8698 100644
--- a/tensorflow/python/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/applications/inception_v3.py
@@ -108,7 +108,7 @@
     ValueError: if `classifier_activation` is not `softmax` or `None` when
       using a pretrained top layer.
   """
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index f294917..b628f2c 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -164,7 +164,7 @@
     layers = VersionAwareLayers()
   if kwargs:
     raise ValueError('Unknown argument(s): %s' % (kwargs,))
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py
index 77392d4..13867a4 100644
--- a/tensorflow/python/keras/applications/mobilenet_v2.py
+++ b/tensorflow/python/keras/applications/mobilenet_v2.py
@@ -180,7 +180,7 @@
     layers = VersionAwareLayers()
   if kwargs:
     raise ValueError('Unknown argument(s): %s' % (kwargs,))
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/mobilenet_v3.py b/tensorflow/python/keras/applications/mobilenet_v3.py
index 44ba6fd..a83716e 100644
--- a/tensorflow/python/keras/applications/mobilenet_v3.py
+++ b/tensorflow/python/keras/applications/mobilenet_v3.py
@@ -158,7 +158,7 @@
                 pooling=None,
                 dropout_rate=0.2,
                 classifier_activation='softmax'):
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py
index 4f71165..3f14646 100644
--- a/tensorflow/python/keras/applications/nasnet.py
+++ b/tensorflow/python/keras/applications/nasnet.py
@@ -150,7 +150,7 @@
     ValueError: if `classifier_activation` is not `softmax` or `None` when
       using a pretrained top layer.
   """
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/resnet.py b/tensorflow/python/keras/applications/resnet.py
index 6131039..720704f 100644
--- a/tensorflow/python/keras/applications/resnet.py
+++ b/tensorflow/python/keras/applications/resnet.py
@@ -137,7 +137,7 @@
     layers = VersionAwareLayers()
   if kwargs:
     raise ValueError('Unknown argument(s): %s' % (kwargs,))
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py
index 0d50899..9b46ca3 100644
--- a/tensorflow/python/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/applications/vgg16.py
@@ -113,7 +113,7 @@
     ValueError: if `classifier_activation` is not `softmax` or `None` when
       using a pretrained top layer.
   """
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py
index c033f57..54dc62c 100644
--- a/tensorflow/python/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/applications/vgg19.py
@@ -113,7 +113,7 @@
     ValueError: if `classifier_activation` is not `softmax` or `None` when
       using a pretrained top layer.
   """
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py
index b954f77..80027b2 100644
--- a/tensorflow/python/keras/applications/xception.py
+++ b/tensorflow/python/keras/applications/xception.py
@@ -113,7 +113,7 @@
     ValueError: if `classifier_activation` is not `softmax` or `None` when
       using a pretrained top layer.
   """
-  if not (weights in {'imagenet', None} or file_io.file_exists(weights)):
+  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
     raise ValueError('The `weights` argument should be either '
                      '`None` (random initialization), `imagenet` '
                      '(pre-training on ImageNet), '
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index bde1739..28bbe42 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -84,6 +84,7 @@
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
+from tensorflow.tools.docs import doc_controls
 
 py_all = all
 py_sum = sum
@@ -163,6 +164,7 @@
 
 
 @keras_export('keras.backend.backend')
+@doc_controls.do_not_generate_docs
 def backend():
   """Publicly accessible method for determining the current backend.
 
@@ -176,6 +178,7 @@
 
 @keras_export('keras.backend.cast_to_floatx')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def cast_to_floatx(x):
   """Cast a Numpy array to the default Keras float type.
 
@@ -310,6 +313,7 @@
 
 
 @keras_export('keras.backend.manual_variable_initialization')
+@doc_controls.do_not_generate_docs
 def manual_variable_initialization(value):
   """Sets the manual variable initialization flag.
 
@@ -327,6 +331,7 @@
 
 
 @keras_export('keras.backend.learning_phase')
+@doc_controls.do_not_generate_docs
 def learning_phase():
   """Returns the learning phase flag.
 
@@ -395,6 +400,7 @@
 
 
 @keras_export('keras.backend.set_learning_phase')
+@doc_controls.do_not_generate_docs
 def set_learning_phase(value):
   """Sets the learning phase to a fixed value.
 
@@ -461,6 +467,7 @@
 
 @keras_export('keras.backend.learning_phase_scope')
 @tf_contextlib.contextmanager
+@doc_controls.do_not_generate_docs
 def learning_phase_scope(value):
   """Provides a scope within which the learning phase is equal to `value`.
 
@@ -833,10 +840,11 @@
   Returns:
       A tensor.
   """
-  return ops.convert_to_tensor_v2(x, dtype=dtype)
+  return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
 
 
 @keras_export('keras.backend.is_sparse')
+@doc_controls.do_not_generate_docs
 def is_sparse(tensor):
   """Returns whether a tensor is a sparse tensor.
 
@@ -865,6 +873,7 @@
 
 @keras_export('keras.backend.to_dense')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def to_dense(tensor):
   """Converts a sparse tensor into a dense tensor and returns it.
 
@@ -892,6 +901,7 @@
 
 
 @keras_export('keras.backend.name_scope', v1=[])
+@doc_controls.do_not_generate_docs
 def name_scope(name):
   """A context manager for use when defining a Python op.
 
@@ -923,6 +933,7 @@
 
 
 @keras_export('keras.backend.variable')
+@doc_controls.do_not_generate_docs
 def variable(value, dtype=None, name=None, constraint=None):
   """Instantiates a variable and returns it.
 
@@ -1074,6 +1085,7 @@
 
 @keras_export('keras.backend.constant')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def constant(value, dtype=None, shape=None, name=None):
   """Creates a constant tensor.
 
@@ -1147,6 +1159,7 @@
 
 
 @keras_export('keras.backend.placeholder')
+@doc_controls.do_not_generate_docs
 def placeholder(shape=None,
                 ndim=None,
                 dtype=None,
@@ -1265,6 +1278,7 @@
 
 @keras_export('keras.backend.shape')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def shape(x):
   """Returns the symbolic shape of a tensor or variable.
 
@@ -1289,6 +1303,7 @@
 
 
 @keras_export('keras.backend.int_shape')
+@doc_controls.do_not_generate_docs
 def int_shape(x):
   """Returns the shape of tensor or variable as a tuple of int or None entries.
 
@@ -1319,6 +1334,7 @@
 
 
 @keras_export('keras.backend.ndim')
+@doc_controls.do_not_generate_docs
 def ndim(x):
   """Returns the number of axes in a tensor, as an integer.
 
@@ -1348,6 +1364,7 @@
 
 @keras_export('keras.backend.dtype')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def dtype(x):
   """Returns the dtype of a Keras tensor or variable, as a string.
 
@@ -1380,6 +1397,7 @@
 
 
 @keras_export('keras.backend.eval')
+@doc_controls.do_not_generate_docs
 def eval(x):
   """Evaluates the value of a variable.
 
@@ -1402,6 +1420,7 @@
 
 
 @keras_export('keras.backend.zeros')
+@doc_controls.do_not_generate_docs
 def zeros(shape, dtype=None, name=None):
   """Instantiates an all-zeros variable and returns it.
 
@@ -1447,6 +1466,7 @@
 
 @keras_export('keras.backend.ones')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def ones(shape, dtype=None, name=None):
   """Instantiates an all-ones variable and returns it.
 
@@ -1482,6 +1502,7 @@
 
 @keras_export('keras.backend.eye')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def eye(size, dtype=None, name=None):
   """Instantiate an identity matrix and returns it.
 
@@ -1511,6 +1532,7 @@
 
 
 @keras_export('keras.backend.zeros_like')
+@doc_controls.do_not_generate_docs
 def zeros_like(x, dtype=None, name=None):
   """Instantiates an all-zeros variable of the same shape as another tensor.
 
@@ -1539,6 +1561,7 @@
 
 @keras_export('keras.backend.ones_like')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def ones_like(x, dtype=None, name=None):
   """Instantiates an all-ones variable of the same shape as another tensor.
 
@@ -1577,6 +1600,7 @@
 
 
 @keras_export('keras.backend.random_uniform_variable')
+@doc_controls.do_not_generate_docs
 def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
   """Instantiates a variable with values drawn from a uniform distribution.
 
@@ -1611,6 +1635,7 @@
 
 
 @keras_export('keras.backend.random_normal_variable')
+@doc_controls.do_not_generate_docs
 def random_normal_variable(shape, mean, scale, dtype=None, name=None,
                            seed=None):
   """Instantiates a variable with values drawn from a normal distribution.
@@ -1646,6 +1671,7 @@
 
 
 @keras_export('keras.backend.count_params')
+@doc_controls.do_not_generate_docs
 def count_params(x):
   """Returns the static number of elements in a variable or tensor.
 
@@ -1670,6 +1696,7 @@
 
 @keras_export('keras.backend.cast')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def cast(x, dtype):
   """Casts a tensor to a different dtype and returns it.
 
@@ -1701,11 +1728,13 @@
 
 
 @keras_export('keras.backend.update')
+@doc_controls.do_not_generate_docs
 def update(x, new_x):
   return state_ops.assign(x, new_x)
 
 
 @keras_export('keras.backend.update_add')
+@doc_controls.do_not_generate_docs
 def update_add(x, increment):
   """Update the value of `x` by adding `increment`.
 
@@ -1720,6 +1749,7 @@
 
 
 @keras_export('keras.backend.update_sub')
+@doc_controls.do_not_generate_docs
 def update_sub(x, decrement):
   """Update the value of `x` by subtracting `decrement`.
 
@@ -1734,6 +1764,7 @@
 
 
 @keras_export('keras.backend.moving_average_update')
+@doc_controls.do_not_generate_docs
 def moving_average_update(x, value, momentum):
   """Compute the exponential moving average of a value.
 
@@ -1781,6 +1812,7 @@
 
 @keras_export('keras.backend.dot')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def dot(x, y):
   """Multiplies 2 tensors (and/or variables) and returns a tensor.
 
@@ -1842,6 +1874,7 @@
 
 @keras_export('keras.backend.batch_dot')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def batch_dot(x, y, axes=None):
   """Batchwise dot product.
 
@@ -2031,6 +2064,7 @@
 
 @keras_export('keras.backend.transpose')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def transpose(x):
   """Transposes a tensor and returns it.
 
@@ -2063,6 +2097,7 @@
 
 @keras_export('keras.backend.gather')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def gather(reference, indices):
   """Retrieves the elements of indices `indices` in the tensor `reference`.
 
@@ -2099,6 +2134,7 @@
 
 @keras_export('keras.backend.max')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def max(x, axis=None, keepdims=False):
   """Maximum value in a tensor.
 
@@ -2118,6 +2154,7 @@
 
 @keras_export('keras.backend.min')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def min(x, axis=None, keepdims=False):
   """Minimum value in a tensor.
 
@@ -2137,6 +2174,7 @@
 
 @keras_export('keras.backend.sum')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def sum(x, axis=None, keepdims=False):
   """Sum of the values in a tensor, alongside the specified axis.
 
@@ -2156,6 +2194,7 @@
 
 @keras_export('keras.backend.prod')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def prod(x, axis=None, keepdims=False):
   """Multiplies the values in a tensor, alongside the specified axis.
 
@@ -2175,6 +2214,7 @@
 
 @keras_export('keras.backend.cumsum')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def cumsum(x, axis=0):
   """Cumulative sum of the values in a tensor, alongside the specified axis.
 
@@ -2190,6 +2230,7 @@
 
 @keras_export('keras.backend.cumprod')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def cumprod(x, axis=0):
   """Cumulative product of the values in a tensor, alongside the specified axis.
 
@@ -2204,6 +2245,7 @@
 
 
 @keras_export('keras.backend.var')
+@doc_controls.do_not_generate_docs
 def var(x, axis=None, keepdims=False):
   """Variance of a tensor, alongside the specified axis.
 
@@ -2225,6 +2267,7 @@
 
 @keras_export('keras.backend.std')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def std(x, axis=None, keepdims=False):
   """Standard deviation of a tensor, alongside the specified axis.
 
@@ -2252,6 +2295,7 @@
 
 @keras_export('keras.backend.mean')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def mean(x, axis=None, keepdims=False):
   """Mean of a tensor, alongside the specified axis.
 
@@ -2273,6 +2317,7 @@
 
 @keras_export('keras.backend.any')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def any(x, axis=None, keepdims=False):
   """Bitwise reduction (logical OR).
 
@@ -2290,6 +2335,7 @@
 
 @keras_export('keras.backend.all')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def all(x, axis=None, keepdims=False):
   """Bitwise reduction (logical AND).
 
@@ -2307,6 +2353,7 @@
 
 @keras_export('keras.backend.argmax')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def argmax(x, axis=-1):
   """Returns the index of the maximum value along an axis.
 
@@ -2322,6 +2369,7 @@
 
 @keras_export('keras.backend.argmin')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def argmin(x, axis=-1):
   """Returns the index of the minimum value along an axis.
 
@@ -2337,6 +2385,7 @@
 
 @keras_export('keras.backend.square')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def square(x):
   """Element-wise square.
 
@@ -2351,6 +2400,7 @@
 
 @keras_export('keras.backend.abs')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def abs(x):
   """Element-wise absolute value.
 
@@ -2365,6 +2415,7 @@
 
 @keras_export('keras.backend.sqrt')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def sqrt(x):
   """Element-wise square root.
 
@@ -2382,6 +2433,7 @@
 
 @keras_export('keras.backend.exp')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def exp(x):
   """Element-wise exponential.
 
@@ -2396,6 +2448,7 @@
 
 @keras_export('keras.backend.log')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def log(x):
   """Element-wise log.
 
@@ -2431,6 +2484,7 @@
 
 @keras_export('keras.backend.round')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def round(x):
   """Element-wise rounding to the closest integer.
 
@@ -2447,6 +2501,7 @@
 
 @keras_export('keras.backend.sign')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def sign(x):
   """Element-wise sign.
 
@@ -2461,6 +2516,7 @@
 
 @keras_export('keras.backend.pow')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def pow(x, a):
   """Element-wise exponentiation.
 
@@ -2476,6 +2532,7 @@
 
 @keras_export('keras.backend.clip')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def clip(x, min_value, max_value):
   """Element-wise value clipping.
 
@@ -2500,6 +2557,7 @@
 
 @keras_export('keras.backend.equal')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def equal(x, y):
   """Element-wise equality between two tensors.
 
@@ -2515,6 +2573,7 @@
 
 @keras_export('keras.backend.not_equal')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def not_equal(x, y):
   """Element-wise inequality between two tensors.
 
@@ -2530,6 +2589,7 @@
 
 @keras_export('keras.backend.greater')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def greater(x, y):
   """Element-wise truth value of (x > y).
 
@@ -2545,6 +2605,7 @@
 
 @keras_export('keras.backend.greater_equal')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def greater_equal(x, y):
   """Element-wise truth value of (x >= y).
 
@@ -2560,6 +2621,7 @@
 
 @keras_export('keras.backend.less')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def less(x, y):
   """Element-wise truth value of (x < y).
 
@@ -2575,6 +2637,7 @@
 
 @keras_export('keras.backend.less_equal')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def less_equal(x, y):
   """Element-wise truth value of (x <= y).
 
@@ -2590,6 +2653,7 @@
 
 @keras_export('keras.backend.maximum')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def maximum(x, y):
   """Element-wise maximum of two tensors.
 
@@ -2615,6 +2679,7 @@
 
 @keras_export('keras.backend.minimum')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def minimum(x, y):
   """Element-wise minimum of two tensors.
 
@@ -2630,6 +2695,7 @@
 
 @keras_export('keras.backend.sin')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def sin(x):
   """Computes sin of x element-wise.
 
@@ -2644,6 +2710,7 @@
 
 @keras_export('keras.backend.cos')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def cos(x):
   """Computes cos of x element-wise.
 
@@ -2759,6 +2826,7 @@
 
 
 @keras_export('keras.backend.normalize_batch_in_training')
+@doc_controls.do_not_generate_docs
 def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
   """Computes mean and std for batch then apply batch_normalization on batch.
 
@@ -2790,6 +2858,7 @@
 
 @keras_export('keras.backend.batch_normalization')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
   """Applies batch normalization on x given mean, var, beta and gamma.
 
@@ -2853,6 +2922,7 @@
 
 @keras_export('keras.backend.concatenate')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def concatenate(tensors, axis=-1):
   """Concatenates a list of tensors alongside the specified axis.
 
@@ -2891,6 +2961,7 @@
 
 @keras_export('keras.backend.reshape')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def reshape(x, shape):
   """Reshapes a tensor to the specified shape.
 
@@ -2921,6 +2992,7 @@
 
 @keras_export('keras.backend.permute_dimensions')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def permute_dimensions(x, pattern):
   """Permutes axes in a tensor.
 
@@ -2953,6 +3025,7 @@
 
 @keras_export('keras.backend.resize_images')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def resize_images(x, height_factor, width_factor, data_format,
                   interpolation='nearest'):
   """Resizes the images contained in a 4D tensor.
@@ -3017,6 +3090,7 @@
 
 @keras_export('keras.backend.resize_volumes')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
   """Resizes the volume contained in a 5D tensor.
 
@@ -3050,6 +3124,7 @@
 
 @keras_export('keras.backend.repeat_elements')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def repeat_elements(x, rep, axis):
   """Repeats the elements of a tensor along an axis, like `np.repeat`.
 
@@ -3112,6 +3187,7 @@
 
 @keras_export('keras.backend.repeat')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def repeat(x, n):
   """Repeats a 2D tensor.
 
@@ -3148,6 +3224,7 @@
 
 @keras_export('keras.backend.arange')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def arange(start, stop=None, step=1, dtype='int32'):
   """Creates a 1D tensor containing a sequence of integers.
 
@@ -3187,6 +3264,7 @@
 
 @keras_export('keras.backend.tile')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def tile(x, n):
   """Creates a tensor by tiling `x` by `n`.
 
@@ -3205,6 +3283,7 @@
 
 @keras_export('keras.backend.flatten')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def flatten(x):
   """Flatten a tensor.
 
@@ -3231,6 +3310,7 @@
 
 @keras_export('keras.backend.batch_flatten')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def batch_flatten(x):
   """Turn a nD tensor into a 2D tensor with same 0th dimension.
 
@@ -3257,6 +3337,7 @@
 
 @keras_export('keras.backend.expand_dims')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def expand_dims(x, axis=-1):
   """Adds a 1-sized dimension at index "axis".
 
@@ -3272,6 +3353,7 @@
 
 @keras_export('keras.backend.squeeze')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def squeeze(x, axis):
   """Removes a 1-dimension from the tensor at index "axis".
 
@@ -3287,6 +3369,7 @@
 
 @keras_export('keras.backend.temporal_padding')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def temporal_padding(x, padding=(1, 1)):
   """Pads the middle dimension of a 3D tensor.
 
@@ -3305,6 +3388,7 @@
 
 @keras_export('keras.backend.spatial_2d_padding')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
   """Pads the 2nd and 3rd dimensions of a 4D tensor.
 
@@ -3337,6 +3421,7 @@
 
 @keras_export('keras.backend.spatial_3d_padding')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
   """Pads 5D tensor with zeros along the depth, height, width dimensions.
 
@@ -3382,6 +3467,7 @@
 
 @keras_export('keras.backend.stack')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def stack(x, axis=0):
   """Stacks a list of rank `R` tensors into a rank `R+1` tensor.
 
@@ -3409,6 +3495,7 @@
 
 @keras_export('keras.backend.one_hot')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def one_hot(indices, num_classes):
   """Computes the one-hot representation of an integer tensor.
 
@@ -3429,6 +3516,7 @@
 
 @keras_export('keras.backend.reverse')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def reverse(x, axes):
   """Reverse a tensor along the specified axes.
 
@@ -3475,6 +3563,7 @@
 
 
 @keras_export('keras.backend.get_value')
+@doc_controls.do_not_generate_docs
 def get_value(x):
   """Returns the value of a variable.
 
@@ -3510,6 +3599,7 @@
 
 @keras_export('keras.backend.batch_get_value')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def batch_get_value(tensors):
   """Returns the value of more than one tensor variable.
 
@@ -3533,6 +3623,7 @@
 
 
 @keras_export('keras.backend.set_value')
+@doc_controls.do_not_generate_docs
 def set_value(x, value):
   """Sets the value of a variable, from a Numpy array.
 
@@ -3572,6 +3663,7 @@
 
 @keras_export('keras.backend.batch_set_value')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def batch_set_value(tuples):
   """Sets the values of many tensor variables at once.
 
@@ -3615,6 +3707,7 @@
 
 @keras_export('keras.backend.print_tensor')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def print_tensor(x, message=''):
   """Prints `message` and the tensor value when evaluated.
 
@@ -3916,6 +4009,7 @@
 
 
 @keras_export('keras.backend.function')
+@doc_controls.do_not_generate_docs
 def function(inputs, outputs, updates=None, name=None, **kwargs):
   """Instantiates a Keras function.
 
@@ -3963,6 +4057,7 @@
 
 
 @keras_export('keras.backend.gradients')
+@doc_controls.do_not_generate_docs
 def gradients(loss, variables):
   """Returns the gradients of `loss` w.r.t. `variables`.
 
@@ -3979,6 +4074,7 @@
 
 @keras_export('keras.backend.stop_gradient')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def stop_gradient(variables):
   """Returns `variables` but with zero gradient w.r.t. every other variable.
 
@@ -4396,6 +4492,7 @@
 
 @keras_export('keras.backend.switch')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def switch(condition, then_expression, else_expression):
   """Switches between two operations depending on a scalar value.
 
@@ -4460,6 +4557,7 @@
 
 
 @keras_export('keras.backend.in_train_phase')
+@doc_controls.do_not_generate_docs
 def in_train_phase(x, alt, training=None):
   """Selects `x` in train phase, and `alt` otherwise.
 
@@ -4505,6 +4603,7 @@
 
 
 @keras_export('keras.backend.in_test_phase')
+@doc_controls.do_not_generate_docs
 def in_test_phase(x, alt, training=None):
   """Selects `x` in test phase, and `alt` otherwise.
 
@@ -4530,6 +4629,7 @@
 
 @keras_export('keras.backend.relu')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def relu(x, alpha=0., max_value=None, threshold=0):
   """Rectified linear unit.
 
@@ -4587,6 +4687,7 @@
 
 @keras_export('keras.backend.elu')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def elu(x, alpha=1.):
   """Exponential linear unit.
 
@@ -4606,6 +4707,7 @@
 
 @keras_export('keras.backend.softmax')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def softmax(x, axis=-1):
   """Softmax of a tensor.
 
@@ -4622,6 +4724,7 @@
 
 @keras_export('keras.backend.softplus')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def softplus(x):
   """Softplus of a tensor.
 
@@ -4636,6 +4739,7 @@
 
 @keras_export('keras.backend.softsign')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def softsign(x):
   """Softsign of a tensor.
 
@@ -4650,6 +4754,7 @@
 
 @keras_export('keras.backend.categorical_crossentropy')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
   """Categorical crossentropy between an output tensor and a target tensor.
 
@@ -4661,7 +4766,7 @@
       from_logits: Boolean, whether `output` is the
           result of a softmax, or is a tensor of logits.
       axis: Int specifying the channels axis. `axis=-1` corresponds to data
-          format `channels_last', and `axis=1` corresponds to data format
+          format `channels_last`, and `axis=1` corresponds to data format
           `channels_first`.
 
   Returns:
@@ -4692,8 +4797,8 @@
   [0. 0. 0.]
 
   """
-  target = ops.convert_to_tensor_v2(target)
-  output = ops.convert_to_tensor_v2(output)
+  target = ops.convert_to_tensor_v2_with_dispatch(target)
+  output = ops.convert_to_tensor_v2_with_dispatch(output)
 
   target.shape.assert_is_compatible_with(output.shape)
   if from_logits:
@@ -4721,6 +4826,7 @@
 
 @keras_export('keras.backend.sparse_categorical_crossentropy')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   """Categorical crossentropy with integer targets.
 
@@ -4732,7 +4838,7 @@
       from_logits: Boolean, whether `output` is the
           result of a softmax, or is a tensor of logits.
       axis: Int specifying the channels axis. `axis=-1` corresponds to data
-          format `channels_last', and `axis=1` corresponds to data format
+          format `channels_last`, and `axis=1` corresponds to data format
           `channels_first`.
 
   Returns:
@@ -4741,8 +4847,8 @@
   Raises:
       ValueError: if `axis` is neither -1 nor one of the axes of `output`.
   """
-  target = ops.convert_to_tensor_v2(target)
-  output = ops.convert_to_tensor_v2(output)
+  target = ops.convert_to_tensor_v2_with_dispatch(target)
+  output = ops.convert_to_tensor_v2_with_dispatch(output)
 
   if (not from_logits and
       not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
@@ -4805,6 +4911,7 @@
 
 @keras_export('keras.backend.binary_crossentropy')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def binary_crossentropy(target, output, from_logits=False):
   """Binary crossentropy between an output tensor and a target tensor.
 
@@ -4818,8 +4925,8 @@
   Returns:
       A tensor.
   """
-  target = ops.convert_to_tensor_v2(target)
-  output = ops.convert_to_tensor_v2(output)
+  target = ops.convert_to_tensor_v2_with_dispatch(target)
+  output = ops.convert_to_tensor_v2_with_dispatch(output)
 
   if from_logits:
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
@@ -4844,6 +4951,7 @@
 
 @keras_export('keras.backend.sigmoid')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def sigmoid(x):
   """Element-wise sigmoid.
 
@@ -4858,6 +4966,7 @@
 
 @keras_export('keras.backend.hard_sigmoid')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def hard_sigmoid(x):
   """Segment-wise linear approximation of sigmoid.
 
@@ -4881,6 +4990,7 @@
 
 @keras_export('keras.backend.tanh')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def tanh(x):
   """Element-wise tanh.
 
@@ -4895,6 +5005,7 @@
 
 @keras_export('keras.backend.dropout')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def dropout(x, level, noise_shape=None, seed=None):
   """Sets entries in `x` to zero at random, while scaling the entire tensor.
 
@@ -4916,6 +5027,7 @@
 
 @keras_export('keras.backend.l2_normalize')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def l2_normalize(x, axis=None):
   """Normalizes a tensor wrt the L2 norm alongside the specified axis.
 
@@ -4931,6 +5043,7 @@
 
 @keras_export('keras.backend.in_top_k')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def in_top_k(predictions, targets, k):
   """Returns whether the `targets` are in the top `k` `predictions`.
 
@@ -5034,6 +5147,7 @@
 
 @keras_export('keras.backend.conv1d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def conv1d(x,
            kernel,
            strides=1,
@@ -5085,6 +5199,7 @@
 
 @keras_export('keras.backend.conv2d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def conv2d(x,
            kernel,
            strides=(1, 1),
@@ -5129,6 +5244,7 @@
 
 @keras_export('keras.backend.conv2d_transpose')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def conv2d_transpose(x,
                      kernel,
                      output_shape,
@@ -5270,6 +5386,7 @@
 
 @keras_export('keras.backend.separable_conv2d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def separable_conv2d(x,
                      depthwise_kernel,
                      pointwise_kernel,
@@ -5328,6 +5445,7 @@
 
 @keras_export('keras.backend.depthwise_conv2d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def depthwise_conv2d(x,
                      depthwise_kernel,
                      strides=(1, 1),
@@ -5378,6 +5496,7 @@
 
 @keras_export('keras.backend.conv3d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def conv3d(x,
            kernel,
            strides=(1, 1, 1),
@@ -5481,6 +5600,7 @@
 
 @keras_export('keras.backend.pool2d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def pool2d(x,
            pool_size,
            strides=(1, 1),
@@ -5541,6 +5661,7 @@
 
 @keras_export('keras.backend.pool3d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def pool3d(x,
            pool_size,
            strides=(1, 1, 1),
@@ -5672,6 +5793,7 @@
 
 @keras_export('keras.backend.local_conv1d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
   """Apply 1D conv with un-shared weights.
 
@@ -5708,6 +5830,7 @@
 
 @keras_export('keras.backend.local_conv2d')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def local_conv2d(inputs,
                  kernel,
                  kernel_size,
@@ -5750,6 +5873,7 @@
 
 @keras_export('keras.backend.bias_add')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def bias_add(x, bias, data_format=None):
   """Adds a bias vector to a tensor.
 
@@ -5795,6 +5919,7 @@
 
 @keras_export('keras.backend.random_normal')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
   """Returns a tensor with normal distribution of values.
 
@@ -5832,6 +5957,7 @@
 
 @keras_export('keras.backend.random_uniform')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
   """Returns a tensor with uniform distribution of values.
 
@@ -5865,6 +5991,7 @@
 
 @keras_export('keras.backend.random_binomial')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def random_binomial(shape, p=0.0, dtype=None, seed=None):
   """Returns a tensor with random binomial distribution of values.
 
@@ -5898,6 +6025,7 @@
 
 @keras_export('keras.backend.random_bernoulli')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def random_bernoulli(shape, p=0.0, dtype=None, seed=None):
   """Returns a tensor with random bernoulli distribution of values.
 
@@ -5921,6 +6049,7 @@
 
 @keras_export('keras.backend.truncated_normal')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
   """Returns a tensor with truncated random normal distribution of values.
 
@@ -5956,6 +6085,7 @@
 
 @keras_export('keras.backend.ctc_label_dense_to_sparse')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def ctc_label_dense_to_sparse(labels, label_lengths):
   """Converts CTC labels from dense to sparse.
 
@@ -6003,6 +6133,7 @@
 
 @keras_export('keras.backend.ctc_batch_cost')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def ctc_batch_cost(y_true, y_pred, input_length, label_length):
   """Runs CTC loss algorithm on each batch element.
 
@@ -6036,6 +6167,7 @@
 
 @keras_export('keras.backend.ctc_decode')
 @dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
 def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
   """Decodes the output of a softmax.
 
@@ -6092,6 +6224,7 @@
 
 
 @keras_export('keras.backend.map_fn')
+@doc_controls.do_not_generate_docs
 def map_fn(fn, elems, name=None, dtype=None):
   """Map the function fn over the elements elems and return the outputs.
 
@@ -6108,6 +6241,7 @@
 
 
 @keras_export('keras.backend.foldl')
+@doc_controls.do_not_generate_docs
 def foldl(fn, elems, initializer=None, name=None):
   """Reduce elems using fn to combine them from left to right.
 
@@ -6125,6 +6259,7 @@
 
 
 @keras_export('keras.backend.foldr')
+@doc_controls.do_not_generate_docs
 def foldr(fn, elems, initializer=None, name=None):
   """Reduce elems using fn to combine them from right to left.
 
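The hunks above attach one decorator across backend.py. As a minimal sketch of
what it does (assuming `doc_controls` still lives under tensorflow/tools/docs,
as imported at this revision, and that `legacy_backend_helper` is a
hypothetical name): the decorator merely tags the callable so the API-doc
generator skips it; the symbol remains importable and fully functional.

from tensorflow.tools.docs import doc_controls

@doc_controls.do_not_generate_docs
def legacy_backend_helper(x):
  """Still callable from user code, but omitted from the generated API docs."""
  return x
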
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 2e0274a..bbbed2f 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -491,7 +491,7 @@
                                      input_shape_b=(4, 7))
 
   def test_relu(self):
-    x = ops.convert_to_tensor_v2([[-4, 0], [2, 7]], 'float32')
+    x = ops.convert_to_tensor_v2_with_dispatch([[-4, 0], [2, 7]], 'float32')
 
     # standard relu
     relu_op = backend.relu(x)
@@ -1310,7 +1310,7 @@
     inputs = backend.variable(input_val)
     initial_states = [
         backend.variable(init_state_val),
-        ops.convert_to_tensor_v2(
+        ops.convert_to_tensor_v2_with_dispatch(
             np.concatenate([init_state_val, init_state_val], axis=-1))
     ]
     mask = backend.variable(np_mask)
@@ -1617,9 +1617,11 @@
     p = backend.placeholder()
     o = backend.categorical_crossentropy(t, p)
 
-    t_val = ops.convert_to_tensor_v2([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
-    p_val = ops.convert_to_tensor_v2([[.9, .05, .05], [.05, .89, .06],
-                                      [.05, .01, .94]])
+    t_val = ops.convert_to_tensor_v2_with_dispatch([[1., 0., 0.], [0., 1., 0.],
+                                                    [0., 0., 1.]])
+    p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05],
+                                                    [.05, .89, .06],
+                                                    [.05, .01, .94]])
     f = backend.function([t, p], o)
 
     result = f([t_val, p_val])
@@ -1633,7 +1635,8 @@
     self.assertArrayNear(result, [.105, .065, .111], 1e-3)
 
     # from logits
-    p_val = ops.convert_to_tensor_v2([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+    p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.],
+                                                    [2., 3., 5.]])
     o = backend.categorical_crossentropy(t, p, from_logits=True)
     f = backend.function([t, p], o)
 
@@ -1685,9 +1688,10 @@
     p = backend.placeholder()
     o = backend.sparse_categorical_crossentropy(t, p)
 
-    t_val = ops.convert_to_tensor_v2([0, 1, 2])
-    p_val = ops.convert_to_tensor_v2([[.9, .05, .05], [.05, .89, .06],
-                                      [.05, .01, .94]])
+    t_val = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2])
+    p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05],
+                                                    [.05, .89, .06],
+                                                    [.05, .01, .94]])
     f = backend.function([t, p], o)
 
     result = f([t_val, p_val])
@@ -1703,7 +1707,8 @@
       _ = f([t_val, p_val])
 
     # from logits
-    p_val = ops.convert_to_tensor_v2([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+    p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.],
+                                                    [2., 3., 5.]])
     o = backend.sparse_categorical_crossentropy(t, p, from_logits=True)
     f = backend.function([t, p], o)
 
@@ -2124,9 +2129,10 @@
     self.assertEqual(backend.eval(tensor), [9.0])
 
   def test_unequal_rank(self):
-    x = ops.convert_to_tensor_v2(
+    x = ops.convert_to_tensor_v2_with_dispatch(
         np.array([[1, 2, 3], [4, 5, 6]]), dtype='float32')
-    y = ops.convert_to_tensor_v2(np.array([1, 2, 3]), dtype='float32')
+    y = ops.convert_to_tensor_v2_with_dispatch(
+        np.array([1, 2, 3]), dtype='float32')
 
     def true_func():
       return x
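The test updates above swap `ops.convert_to_tensor_v2` for
`ops.convert_to_tensor_v2_with_dispatch`, which routes the conversion through
TensorFlow's dispatch machinery so that overrides registered via
`dispatch.add_dispatch_support` (as added throughout backend.py above) see the
call. A hedged sketch from user code, where the public `tf.convert_to_tensor`
is the dispatch-aware entry point:

import numpy as np
import tensorflow as tf

# Public-API stand-in for the internal *_with_dispatch call used in the tests.
x = tf.convert_to_tensor(np.array([[-4., 0.], [2., 7.]]), dtype='float32')
print(tf.keras.backend.relu(x).numpy())  # [[0. 0.], [2. 7.]]
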
diff --git a/tensorflow/python/keras/benchmarks/BUILD b/tensorflow/python/keras/benchmarks/BUILD
index 2252f88..95e88ca 100644
--- a/tensorflow/python/keras/benchmarks/BUILD
+++ b/tensorflow/python/keras/benchmarks/BUILD
@@ -67,6 +67,7 @@
         "no_oss_py38",  # TODO(b/162044699)
     ],
     deps = [
+        ":profiler_lib",
         "//tensorflow:tensorflow_py",
     ],
 )
@@ -76,6 +77,7 @@
     srcs = ["model_components_benchmarks_test.py"],
     python_version = "PY3",
     deps = [
+        ":profiler_lib",
         "//tensorflow:tensorflow_py",
     ],
 )
diff --git a/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD b/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD
index 25a81cc..66246d8 100644
--- a/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD
+++ b/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD
@@ -28,6 +28,7 @@
     srcs = ["saved_model_benchmark_util.py"],
     deps = [
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
 
@@ -41,6 +42,7 @@
     deps = [
         ":saved_model_benchmark_util",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
 
@@ -54,6 +56,7 @@
     deps = [
         ":saved_model_benchmark_util",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
 
@@ -67,6 +70,7 @@
     deps = [
         ":saved_model_benchmark_util",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
 
@@ -80,6 +84,7 @@
     deps = [
         ":saved_model_benchmark_util",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
 
@@ -93,6 +98,7 @@
     deps = [
         ":saved_model_benchmark_util",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
 
@@ -106,6 +112,7 @@
     deps = [
         ":saved_model_benchmark_util",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
 
@@ -119,6 +126,7 @@
     deps = [
         ":saved_model_benchmark_util",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
 
@@ -132,5 +140,6 @@
     deps = [
         ":saved_model_benchmark_util",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/benchmarks:profiler_lib",
     ],
 )
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 3469ccb..7b1cc29 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -675,6 +675,10 @@
 
     Subclasses should override for any actions to run.
 
+    Note that if the `steps_per_execution` argument to `compile` in
+    `tf.keras.Model` is set to `N`, this method will only be called every `N`
+    batches.
+
     Arguments:
         batch: Integer, index of batch within the current epoch.
         logs: Dict, contains the return value of `model.train_step`. Typically,
@@ -691,6 +695,10 @@
 
     Subclasses should override for any actions to run.
 
+    Note that if the `steps_per_execution` argument to `compile` in
+    `tf.keras.Model` is set to `N`, this method will only be called every `N`
+    batches.
+
     Arguments:
         batch: Integer, index of batch within the current epoch.
         logs: Dict. Aggregated metric results up until this batch.
@@ -708,6 +716,10 @@
 
     Subclasses should override for any actions to run.
 
+    Note that if the `steps_per_execution` argument to `compile` in
+    `tf.keras.Model` is set to `N`, this method will only be called every `N`
+    batches.
+
     Arguments:
         batch: Integer, index of batch within the current epoch.
         logs: Dict, contains the return value of `model.test_step`. Typically,
@@ -725,6 +737,10 @@
 
     Subclasses should override for any actions to run.
 
+    Note that if the `steps_per_execution` argument to `compile` in
+    `tf.keras.Model` is set to `N`, this method will only be called every `N`
+    batches.
+
     Arguments:
         batch: Integer, index of batch within the current epoch.
         logs: Dict. Aggregated metric results up until this batch.
@@ -737,6 +753,10 @@
 
     Subclasses should override for any actions to run.
 
+    Note that if the `steps_per_execution` argument to `compile` in
+    `tf.keras.Model` is set to `N`, this method will only be called every `N`
+    batches.
+
     Arguments:
         batch: Integer, index of batch within the current epoch.
         logs: Dict, contains the return value of `model.predict_step`,
@@ -751,6 +771,10 @@
 
     Subclasses should override for any actions to run.
 
+    Note that if the `steps_per_execution` argument to `compile` in
+    `tf.keras.Model` is set to `N`, this method will only be called every `N`
+    batches.
+
     Arguments:
         batch: Integer, index of batch within the current epoch.
         logs: Dict. Aggregated metric results up until this batch.
@@ -896,10 +920,15 @@
   """Callback that terminates training when a NaN loss is encountered.
   """
 
+  def __init__(self):
+    super(TerminateOnNaN, self).__init__()
+    self._supports_tf_logs = True
+
   def on_batch_end(self, batch, logs=None):
     logs = logs or {}
     loss = logs.get('loss')
     if loss is not None:
+      loss = tf_utils.to_numpy_or_python_type(loss)
       if np.isnan(loss) or np.isinf(loss):
         print('Batch %d: Invalid loss, terminating training' % (batch))
         self.model.stop_training = True
@@ -1156,7 +1185,7 @@
       save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
         the model after each epoch. When using integer, the callback saves the
         model at end of this many batches. If the `Model` is compiled with
-        `experimental_steps_per_execution=N`, then the saving criteria will be
+        `steps_per_execution=N`, then the saving criteria will be
         checked every Nth batch. Note that if the saving isn't aligned to
         epochs, the monitored metric may potentially be less reliable (it
         could reflect as little as 1 batch, since the metrics get reset every
@@ -1259,16 +1288,6 @@
       self.save_weights_only = True
 
   def on_train_begin(self, logs=None):
-    # pylint: disable=protected-access
-    if self.model._in_multi_worker_mode:
-      logging.warning(
-          'Automatic model reloading for interrupted job was removed from '
-          'the `ModelCheckpoint` callback in multi-worker mode, please use the '
-          '`keras.callbacks.experimental.BackupAndRestore` callback instead. '
-          'See this tutorial for details: '
-          'https://www.tensorflow.org/tutorials/distribute/'
-          'multi_worker_with_keras#backupandrestore_callback.'
-      )
     if self.load_weights_on_restart:
       filepath_to_load = (
           self._get_most_recently_modified_file_matching_pattern(self.filepath))
@@ -1399,9 +1418,10 @@
   def _checkpoint_exists(self, filepath):
     """Returns whether the checkpoint `filepath` refers to exists."""
     if filepath.endswith('.h5'):
-      return file_io.file_exists(filepath)
-    tf_saved_model_exists = file_io.file_exists(filepath)
-    tf_weights_only_checkpoint_exists = file_io.file_exists(filepath + '.index')
+      return file_io.file_exists_v2(filepath)
+    tf_saved_model_exists = file_io.file_exists_v2(filepath)
+    tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
+        filepath + '.index')
     return tf_saved_model_exists or tf_weights_only_checkpoint_exists
 
   def _get_most_recently_modified_file_matching_pattern(self, pattern):
@@ -1466,7 +1486,7 @@
     n_file_with_latest_mod_time = 0
     file_path_with_largest_file_name = None
 
-    if file_io.file_exists(dir_name):
+    if file_io.file_exists_v2(dir_name):
       for file_name in os.listdir(dir_name):
         # Only consider if `file_name` matches the pattern.
         if re.match(base_name_regex, file_name):
@@ -2416,7 +2436,7 @@
     """Resets wait counter and cooldown counter.
     """
     if self.mode not in ['auto', 'min', 'max']:
-      logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
+      logging.warning('Learning rate reduction mode %s is unknown, '
                       'falling back to auto mode.', self.mode)
       self.mode = 'auto'
     if (self.mode == 'min' or
@@ -2437,7 +2457,7 @@
     logs['lr'] = K.get_value(self.model.optimizer.lr)
     current = logs.get(self.monitor)
     if current is None:
-      logging.warning('Reduce LR on plateau conditioned on metric `%s` '
+      logging.warning('Learning rate reduction is conditioned on metric `%s` '
                       'which is not available. Available metrics are: %s',
                       self.monitor, ','.join(list(logs.keys())))
 
@@ -2505,7 +2525,7 @@
 
   def on_train_begin(self, logs=None):
     if self.append:
-      if file_io.file_exists(self.filename):
+      if file_io.file_exists_v2(self.filename):
         with open(self.filename, 'r' + self.file_flags) as f:
           self.append_header = not bool(len(f.readline()))
       mode = 'a'
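The docstring notes added above say the batch-level hooks fire once per
`tf.function` execution rather than once per batch when `steps_per_execution`
exceeds 1. A small sketch of the observable effect, assuming TF 2.4+ where
`steps_per_execution` is no longer experimental:

import numpy as np
import tensorflow as tf

class BatchCounter(tf.keras.callbacks.Callback):
  """Counts invocations of the train-batch hook."""

  def __init__(self):
    super(BatchCounter, self).__init__()
    self.calls = 0

  def on_train_batch_end(self, batch, logs=None):
    self.calls += 1

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
# 100 batches per epoch, 20 batches per tf.function execution.
model.compile('sgd', 'mse', steps_per_execution=20)

counter = BatchCounter()
x, y = np.ones((100, 10)), np.ones((100, 1))
model.fit(x, y, batch_size=1, epochs=1, callbacks=[counter], verbose=0)
print(counter.calls)  # expected: 5, one call per 20-batch execution
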
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 9fd8bf8..1eaa3dd 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -935,7 +935,7 @@
                                            verbose=0)
 
     with context.eager_mode():
-      tensor = ops.convert_to_tensor(1.)
+      tensor = ops.convert_to_tensor_v2_with_dispatch(1.)
 
     def mock_numpy():
       raise RuntimeError(
@@ -975,7 +975,7 @@
                                            verbose=2)
 
     with context.eager_mode():
-      tensor = ops.convert_to_tensor(1.)
+      tensor = ops.convert_to_tensor_v2_with_dispatch(1.)
 
     def mock_numpy():
       raise RuntimeError(
@@ -2193,7 +2193,7 @@
                                            steps=100,
                                            verbose=0)
 
-    tensor = ops.convert_to_tensor(1.)
+    tensor = ops.convert_to_tensor_v2_with_dispatch(1.)
 
     def mock_numpy():
       raise RuntimeError(
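For context on the `TerminateOnNaN` change shown earlier: declaring
`_supports_tf_logs` means the callback may now receive tensor-valued logs, so
the loss is converted to a NumPy/Python value before the NaN/Inf check. Usage
is unchanged; a minimal sketch against the public API:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile('sgd', 'mse')

x = np.ones((8, 4), dtype=np.float32)
y = np.full((8, 1), np.nan, dtype=np.float32)  # NaN targets force a NaN loss

# Training stops after the first batch whose loss is NaN or Inf.
model.fit(x, y, epochs=10, verbose=0,
          callbacks=[tf.keras.callbacks.TerminateOnNaN()])
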
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index e116ba9..82f4dcb 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -245,7 +245,6 @@
     main = "custom_training_loop_models_test.py",
     tags = [
         "multi_and_single_gpu",
-        "no_cuda11",
     ],
     tpu_tags = [
         "no_oss",  # b/153615544.
@@ -417,7 +416,7 @@
     srcs = ["keras_embedding_model_correctness_test.py"],
     full_precision = True,
     main = "keras_embedding_model_correctness_test.py",
-    shard_count = 4,
+    shard_count = 8,
     tags = [
         "multi_and_single_gpu",
         "no_windows_gpu",
@@ -469,6 +468,21 @@
 )
 
 distribute_py_test(
+    name = "keras_models_test",
+    srcs = ["keras_models_test.py"],
+    main = "keras_models_test.py",
+    tags = [
+        "multi_and_single_gpu",
+    ],
+    deps = [
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/eager:test",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+distribute_py_test(
     name = "keras_rnn_model_correctness_test",
     size = "medium",
     srcs = ["keras_rnn_model_correctness_test.py"],
@@ -743,6 +757,7 @@
     tags = [
         "noasan",  # TODO(b/156029134)
         "nomsan",  # TODO(b/156029134)
+        "notap",  # TODO(b/165865820): restore when not flaky
         "notsan",  # TODO(b/156029134)
     ],
     deps = [
@@ -821,7 +836,9 @@
     srcs = ["parameter_server_training_test.py"],
     python_version = "PY3",
     shard_count = 1,
-    tags = ["no_oss"],  # TODO(b/162119374): enable it in OSS.
+    tags = [
+        "no_oss",  # TODO(b/162119374): enable it in OSS.
+    ],
     deps = [
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
index fe55712..b6b9239 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py
@@ -52,15 +52,17 @@
     return x
 
 
+@combinations.generate(
+    combinations.combine(
+        distribution=(strategy_combinations.all_strategies +
+                      strategy_combinations.multiworker_strategies),
+        mode=["eager"],
+    )
+)
 class KerasModelsTest(test.TestCase, parameterized.TestCase):
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_single_keras_layer_experimental_run(self, distribution):
-    dataset = self._get_dataset()
+  def test_single_keras_layer_run(self, distribution):
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -72,7 +74,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         return grads
 
@@ -83,72 +85,33 @@
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_keras_model_creation_experimental_run(self, distribution):
-    dataset = self._get_dataset()
+  def test_keras_model_optimizer_run(self, distribution):
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
-      model = self._get_model()
-
-    @def_function.function
-    def train_step(iterator):
-      def step_fn(inputs):
-        images, targets = inputs
-        with backprop.GradientTape() as tape:
-          outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
-        grads = tape.gradient(loss, model.variables)
-        return grads
-
-      outputs = distribution.run(
-          step_fn, args=(next(iterator),))
-      return nest.map_structure(distribution.experimental_local_results,
-                                outputs)
-
-    train_step(input_iterator)
-
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_keras_model_optimizer_experimental_run(self, distribution):
-    dataset = self._get_dataset()
-    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
-
-    with distribution.scope():
-      model = self._get_model()
+      model = _get_model()
       optimizer = keras.optimizer_v2.rmsprop.RMSprop()
 
     @def_function.function
-    def train_step(iterator):
+    def train_step(replicated_inputs):
       def step_fn(inputs):
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
         return loss
 
-      outputs = distribution.run(
-          step_fn, args=(next(iterator),))
+      outputs = distribution.run(step_fn, args=(replicated_inputs,))
       return nest.map_structure(distribution.experimental_local_results,
                                 outputs)
 
-    train_step(input_iterator)
+    for x in input_iterator:
+      train_step(x)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_keras_subclass_model_optimizer_experimental_run(self, distribution):
+  def test_keras_subclass_model_optimizer_run(self, distribution):
     def get_subclass_model():
 
       class KerasSubclassModel(keras.Model):
@@ -161,7 +124,7 @@
           return self.l(x)
 
       return KerasSubclassModel()
-    dataset = self._get_dataset()
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -174,29 +137,23 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
         return loss
 
-      outputs = distribution.run(
-          step_fn, args=(next(iterator),))
+      outputs = distribution.run(step_fn, args=(next(iterator),))
       return nest.map_structure(distribution.experimental_local_results,
                                 outputs)
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_keras_model_optimizer_experimental_run_loop(self, distribution):
-    dataset = self._get_dataset()
+  def test_keras_model_optimizer_run_loop(self, distribution):
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
-      model = self._get_model()
+      model = _get_model()
       optimizer = keras.optimizer_v2.rmsprop.RMSprop()
 
     @def_function.function
@@ -205,27 +162,22 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
         return loss
 
-      for _ in range(5):
+      for _ in math_ops.range(4):
         distribution.run(step_fn, args=(next(iterator),))
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
   def test_batch_norm_with_dynamic_batch(self, distribution):
     inputs = np.zeros((10, 3, 3, 3), dtype=np.float32)
     targets = np.zeros((10, 4), dtype=np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
     dataset = dataset.repeat()
-    dataset = dataset.batch(10, drop_remainder=False)
+    dataset = dataset.batch(10)
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -242,7 +194,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images, training=True)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
         return loss
@@ -251,39 +203,13 @@
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]))
-  def test_model_predict_with_dynamic_batch(self, distribution):
-    input_data = np.random.random([1, 32, 64, 64, 3])
-    input_shape = tuple(input_data.shape[1:])
-
-    def build_model():
-      model = keras.models.Sequential()
-      model.add(
-          keras.layers.ConvLSTM2D(
-              4,
-              kernel_size=(4, 4),
-              activation="sigmoid",
-              padding="same",
-              input_shape=input_shape))
-      model.add(keras.layers.GlobalMaxPooling2D())
-      model.add(keras.layers.Dense(2, activation="sigmoid"))
-      return model
-
-    with distribution.scope():
-      model = build_model()
-      model.compile(loss="binary_crossentropy", optimizer="adam")
-      result = model.predict(input_data)
-      self.assertEqual(result.shape, (1, 2))
-
+  # TODO(b/165912857): Re-enable.
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.all_strategies,
           mode=["eager"]
       ))
-  def test_lstm(self, distribution):
+  def DISABLED_test_lstm(self, distribution):
 
     batch_size = 32
 
@@ -331,9 +257,6 @@
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies, mode=["eager"]))
   def test_nested_tf_functions(self, distribution):
     # The test builds two computations with keras layers, one with nested
     # tf.function, and the other without nested tf.function. We run these
@@ -343,7 +266,7 @@
     inputs = np.random.random((10, 3)).astype(np.float32)
     targets = np.ones((10, 4), dtype=np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
-    dataset = dataset.batch(10, drop_remainder=True)
+    dataset = dataset.batch(10)
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     def get_model():
@@ -366,7 +289,7 @@
 
     def compute_loss(images, targets):
       outputs = model(images)
-      return math_ops.reduce_sum(outputs - targets)
+      return keras.losses.mean_squared_error(targets, outputs)
 
     @def_function.function
     def train_step_without_nested_tf_function(inputs):
@@ -383,7 +306,7 @@
     @def_function.function
     def compute_loss2(images, targets):
       outputs = model2(images)
-      return math_ops.reduce_sum(outputs - targets)
+      return keras.losses.mean_squared_error(targets, outputs)
 
     @def_function.function
     def train_step_with_nested_tf_function(inputs):
@@ -406,14 +329,11 @@
     for model_v, model2_v in zip(model.variables, model2.variables):
       self.assertAllClose(model_v.numpy(), model2_v.numpy())
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies, mode=["eager"]))
   def test_nested_tf_functions_with_control_flow(self, distribution):
     inputs = np.random.random((10, 3)).astype(np.float32)
     targets = np.ones((10, 4), dtype=np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
-    dataset = dataset.batch(10, drop_remainder=True)
+    dataset = dataset.batch(10)
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     def get_model():
@@ -433,7 +353,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables))
 
@@ -446,13 +366,8 @@
 
     train_steps(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies,
-          mode=["eager"]
-      ))
-  def test_customized_tf_module_experimental_run(self, distribution):
-    dataset = self._get_dataset()
+  def test_customized_tf_module_run(self, distribution):
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -465,7 +380,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         return grads
 
@@ -476,14 +391,11 @@
 
     train_step(input_iterator)
 
-  @combinations.generate(
-      combinations.combine(
-          distribution=strategy_combinations.all_strategies, mode=["eager"]))
   def test_reduce_loss(self, distribution):
     inputs = np.zeros((10, 4), dtype=np.float32)
     targets = np.zeros((10, 1), dtype=np.float32)
     dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
-    dataset = dataset.batch(10, drop_remainder=False)
+    dataset = dataset.batch(10)
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     with distribution.scope():
@@ -505,11 +417,14 @@
     loss = train_step(input_iterator)
     loss = distribution.reduce(reduce_util.ReduceOp.MEAN, loss, axis=0)
 
+
+class KerasModelsXLATest(test.TestCase, parameterized.TestCase):
+
   @combinations.generate(
       combinations.combine(
           distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
   def test_tf_function_experimental_compile(self, distribution):
-    dataset = self._get_dataset()
+    dataset = _get_dataset()
     input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
 
     class CustomDense(keras.layers.Layer):
@@ -537,7 +452,7 @@
         images, targets = inputs
         with backprop.GradientTape() as tape:
           outputs = model(images)
-          loss = math_ops.reduce_sum(outputs - targets)
+          loss = keras.losses.mean_squared_error(targets, outputs)
         grads = tape.gradient(loss, model.variables)
         return grads
 
@@ -548,20 +463,21 @@
 
     train_step(input_iterator)
 
-  def _get_dataset(self):
-    inputs = np.zeros((10, 3), dtype=np.float32)
-    targets = np.zeros((10, 4), dtype=np.float32)
-    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
-    dataset = dataset.repeat(100)
-    dataset = dataset.batch(10, drop_remainder=True)
-    return dataset
 
-  def _get_model(self):
-    x = keras.layers.Input(shape=(3,), name="input")
-    y = keras.layers.Dense(4, name="dense")(x)
-    model = keras.Model(x, y)
-    return model
+def _get_dataset():
+  inputs = np.zeros((31, 3), dtype=np.float32)
+  targets = np.zeros((31, 4), dtype=np.float32)
+  dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+  dataset = dataset.batch(10)
+  return dataset
+
+
+def _get_model():
+  x = keras.layers.Input(shape=(3,), name="input")
+  y = keras.layers.Dense(4, name="dense")(x)
+  model = keras.Model(x, y)
+  return model
 
 
 if __name__ == "__main__":
-  test.main()
+  combinations.main()
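The rewrite above hoists `@combinations.generate` from each test method to the
class, so every method is parameterized over the same strategies, and switches
the entry point to `combinations.main()`, which multi-worker combinations
require. A sketch of the resulting pattern (these modules are
TensorFlow-internal, so the imports mirror this test file rather than a public
API):

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import test

@combinations.generate(
    combinations.combine(
        distribution=strategy_combinations.all_strategies, mode=["eager"]))
class ExampleTest(test.TestCase, parameterized.TestCase):

  def test_runs_under_each_strategy(self, distribution):
    # Every test method receives the class-level `distribution` parameter.
    with distribution.scope():
      pass

if __name__ == "__main__":
  combinations.main()  # needed so multi-worker combinations can spawn workers
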
diff --git a/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py b/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
index b9eee26..b014c88 100644
--- a/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
+++ b/tensorflow/python/keras/distribute/custom_training_loop_optimizer_test.py
@@ -56,8 +56,8 @@
     @def_function.function
     def optimize():
       grads = values.PerReplica([
-          ops.convert_to_tensor([1., 1.]),
-          ops.convert_to_tensor([2., 2.]),
+          ops.convert_to_tensor_v2_with_dispatch([1., 1.]),
+          ops.convert_to_tensor_v2_with_dispatch([2., 2.]),
       ])
 
       def step_fn(grads):
@@ -85,7 +85,7 @@
 
     @def_function.function
     def optimize():
-      grads = ops.convert_to_tensor([1., 1.])
+      grads = ops.convert_to_tensor_v2_with_dispatch([1., 1.])
 
       def step_fn(grads):
         optimizer.apply_gradients(
@@ -107,7 +107,7 @@
       v = variables.Variable([0., 0.])
       optimizer = gradient_descent.SGD(0.1)
 
-    grads = ops.convert_to_tensor([1., 1.])
+    grads = ops.convert_to_tensor_v2_with_dispatch([1., 1.])
 
     def step_fn(grads):
       with self.assertRaises(NotImplementedError):
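
Throughout this change, `ops.convert_to_tensor` and `ops.convert_to_tensor_v2` call sites migrate to `ops.convert_to_tensor_v2_with_dispatch`. A hedged sketch of the call shape; the behavioral intent (a v2 converter that also honors registered dispatch handlers for extension types) is inferred from the name, not stated in this patch:

    import numpy as np
    from tensorflow.python.framework import ops

    # Same signature as the plain v2 converter: value plus optional dtype/name.
    grads = ops.convert_to_tensor_v2_with_dispatch([1., 1.])
    weights = ops.convert_to_tensor_v2_with_dispatch(np.zeros((2, 2)), dtype='float32')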
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index 4ea5342..2329f51 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -1605,6 +1605,8 @@
       self.assertEqual(-1.0, v)
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class TestDistributionStrategyWithKerasModels(test.TestCase,
                                               parameterized.TestCase):
 
@@ -1762,7 +1764,7 @@
       outputs = keras.layers.Dense(1)(x)
       model = keras.Model(inputs, outputs)
 
-    model.compile('sgd', 'mse', experimental_steps_per_execution=10)
+    model.compile('sgd', 'mse', steps_per_execution=10)
 
     bc = BatchCountingCB()
     x, y = np.ones((100, 10, 10, 3)), np.ones((100, 1))
@@ -1786,7 +1788,7 @@
       outputs = keras.layers.Dense(1)(inputs)
       model = keras.Model(inputs, outputs)
 
-    model.compile('sgd', 'mse', experimental_steps_per_execution=20)
+    model.compile('sgd', 'mse', steps_per_execution=20)
 
     bc = BatchCountingCB()
     x, y = np.ones((100, 10)), np.ones((100, 1))
@@ -1810,7 +1812,7 @@
       outputs = keras.layers.Dense(1)(inputs)
       model = keras.Model(inputs, outputs)
 
-    model.compile('sgd', 'mse', experimental_steps_per_execution=20)
+    model.compile('sgd', 'mse', steps_per_execution=20)
 
     x, y = np.ones((100, 10)), np.ones((100, 1))
     ds = dataset_ops.DatasetV2.from_tensor_slices((x, y)).batch(2)
@@ -1846,7 +1848,7 @@
       outputs = keras.layers.Dense(1)(inputs)
       model = keras.Model(inputs, outputs)
 
-    model.compile('sgd', 'mse', experimental_steps_per_execution=500)
+    model.compile('sgd', 'mse', steps_per_execution=500)
 
     x, y = np.ones((100, 10)), np.ones((100, 1))
     bc = BatchCountingCB()
diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
index 6ec7cc2..e04b40e 100644
--- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
@@ -24,6 +24,7 @@
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import keras_correctness_test_base
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
 from tensorflow.python.platform import test
@@ -47,6 +48,8 @@
     return not distribution_strategy_context.has_strategy()
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class TestDistributionStrategyDnnCorrectness(
     keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
 
@@ -240,6 +243,8 @@
     return self.dense4(x)
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
     TestDistributionStrategyDnnCorrectness):
 
diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
index 7e6ae3c..57b9b71 100644
--- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py
@@ -21,11 +21,15 @@
 from tensorflow.python import keras
 from tensorflow.python.distribute import combinations
 from tensorflow.python.eager import context
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import keras_correctness_test_base
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.platform import test
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul. Even if Dense layers run in '
+    'float64, the test sometimes fails with tf32 enabled for unknown reasons')
 class DistributionStrategyCnnCorrectnessTest(
     keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
 
diff --git a/tensorflow/python/keras/distribute/keras_models_test.py b/tensorflow/python/keras/distribute/keras_models_test.py
new file mode 100644
index 0000000..da58c04
--- /dev/null
+++ b/tensorflow/python/keras/distribute/keras_models_test.py
@@ -0,0 +1,60 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras high level APIs, e.g. fit, evaluate and predict."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.eager import test
+
+
+class KerasModelsTest(test.TestCase, parameterized.TestCase):
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def test_lstm_model_with_dynamic_batch(self, distribution):
+    input_data = np.random.random([1, 32, 64, 64, 3])
+    input_shape = tuple(input_data.shape[1:])
+
+    def build_model():
+      model = keras.models.Sequential()
+      model.add(
+          keras.layers.ConvLSTM2D(
+              4,
+              kernel_size=(4, 4),
+              activation="sigmoid",
+              padding="same",
+              input_shape=input_shape))
+      model.add(keras.layers.GlobalMaxPooling2D())
+      model.add(keras.layers.Dense(2, activation="sigmoid"))
+      return model
+
+    with distribution.scope():
+      model = build_model()
+      model.compile(loss="binary_crossentropy", optimizer="adam")
+      result = model.predict(input_data)
+      self.assertEqual(result.shape, (1, 2))
+
+
+if __name__ == "__main__":
+  test.main()
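
For reference, the shape flow the new test relies on, inferred from the layer defaults rather than stated in the patch:

    # input_data                     (1, 32, 64, 64, 3)   batch, time, H, W, channels
    # ConvLSTM2D(4, padding="same")  (1, 64, 64, 4)       last timestep only
    #                                                     (return_sequences defaults to False)
    # GlobalMaxPooling2D()           (1, 4)
    # Dense(2)                       (1, 2)               matches the assertEqual above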
diff --git a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
index aa7f0c2..4e82b7d 100644
--- a/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_rnn_model_correctness_test.py
@@ -69,6 +69,8 @@
     return model
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class DistributionStrategyGruModelCorrectnessTest(
     _DistributionStrategyRnnModelCorrectnessTest):
 
@@ -88,6 +90,8 @@
     self.run_correctness_test(distribution, use_numpy, use_validation_data)
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class DistributionStrategyLstmModelCorrectnessTest(
     _DistributionStrategyRnnModelCorrectnessTest):
 
diff --git a/tensorflow/python/keras/distribute/keras_save_load_test.py b/tensorflow/python/keras/distribute/keras_save_load_test.py
index 65877a0..fc2e2bd 100644
--- a/tensorflow/python/keras/distribute/keras_save_load_test.py
+++ b/tensorflow/python/keras/distribute/keras_save_load_test.py
@@ -20,10 +20,13 @@
 
 from tensorflow.python.distribute import combinations
 from tensorflow.python.eager import test
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import saved_model_test_base as test_base
 from tensorflow.python.keras.saving import save
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class KerasSaveLoadTest(test_base.TestSavedModelBase):
 
   def setUp(self):
diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
index abed53b..0b98e85 100644
--- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
+++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
@@ -37,9 +37,10 @@
 def checkpoint_exists(filepath):
   """Returns whether the checkpoint `filepath` refers to exists."""
   if filepath.endswith('.h5'):
-    return file_io.file_exists(filepath)
-  tf_saved_model_exists = file_io.file_exists(filepath)
-  tf_weights_only_checkpoint_exists = file_io.file_exists(filepath + '.index')
+    return file_io.file_exists_v2(filepath)
+  tf_saved_model_exists = file_io.file_exists_v2(filepath)
+  tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
+      filepath + '.index')
   return tf_saved_model_exists or tf_weights_only_checkpoint_exists
 
 
@@ -145,7 +146,7 @@
       num_epoch = 2
 
       # The saving_filepath shouldn't exist at the beginning (as it's unique).
-      test_obj.assertFalse(file_io.file_exists(saving_filepath))
+      test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
 
       model.fit(
           x=train_ds,
@@ -153,7 +154,7 @@
           steps_per_epoch=steps,
           callbacks=[callbacks.ModelCheckpoint(filepath=saving_filepath)])
 
-      test_obj.assertTrue(file_io.file_exists(saving_filepath))
+      test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
 
     saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint')
 
@@ -185,7 +186,7 @@
       num_epoch = 4
 
       # The saving_filepath shouldn't exist at the beginning (as it's unique).
-      test_obj.assertFalse(file_io.file_exists(saving_filepath))
+      test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
       bar_dir = os.path.join(os.path.dirname(saving_filepath), 'backup')
 
       try:
@@ -204,8 +205,8 @@
 
       multi_process_runner.barrier().wait()
       backup_filepath = os.path.join(bar_dir, 'checkpoint')
-      test_obj.assertTrue(file_io.file_exists(backup_filepath))
-      test_obj.assertTrue(file_io.file_exists(saving_filepath))
+      test_obj.assertTrue(file_io.file_exists_v2(backup_filepath))
+      test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
 
       model.fit(
           x=train_ds,
@@ -217,8 +218,8 @@
               AssertCallback()
           ])
       multi_process_runner.barrier().wait()
-      test_obj.assertFalse(file_io.file_exists(backup_filepath))
-      test_obj.assertTrue(file_io.file_exists(saving_filepath))
+      test_obj.assertFalse(file_io.file_exists_v2(backup_filepath))
+      test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
 
     saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint')
 
@@ -244,7 +245,7 @@
           'logfile_%s_%d' % (task_config['type'], task_config['index']))
 
       # The saving_filepath shouldn't exist at the beginning (as it's unique).
-      test_obj.assertFalse(file_io.file_exists(saving_filepath))
+      test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
 
       model.fit(
           x=train_ds,
@@ -257,7 +258,8 @@
       # `file_io.list_directory()` since the directory may be created at this
       # point.
       test_obj.assertEqual(
-          bool(file_io.list_directory(saving_filepath)), test_base.is_chief())
+          bool(file_io.list_directory_v2(saving_filepath)),
+          test_base.is_chief())
 
     multi_process_runner.run(
         proc_tensorboard_saves_on_chief_but_not_otherwise,
@@ -280,7 +282,7 @@
 
       # Verifies that even if `saving_filepath_for_temp` exists, tensorboard
       # can still save to temporary directory.
-      test_obj.assertTrue(file_io.file_exists(saving_filepath_for_temp))
+      test_obj.assertTrue(file_io.file_exists_v2(saving_filepath_for_temp))
 
       model.fit(
           x=train_ds,
@@ -301,7 +303,7 @@
       num_epoch = 2
 
       # The saving_filepath shouldn't exist at the beginning (as it's unique).
-      test_obj.assertFalse(file_io.file_exists(saving_filepath))
+      test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
 
       multi_process_runner.barrier().wait()
 
@@ -313,7 +315,7 @@
 
       multi_process_runner.barrier().wait()
 
-      test_obj.assertTrue(file_io.list_directory(saving_filepath))
+      test_obj.assertTrue(file_io.list_directory_v2(saving_filepath))
 
     saving_filepath = os.path.join(self.get_temp_dir(), 'logfile')
 
diff --git a/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py b/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py
index f7d64c2..ee2f64d 100644
--- a/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py
+++ b/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py
@@ -159,9 +159,10 @@
       # Make sure chief finishes saving before non-chief's assertions.
       multi_process_runner.barrier().wait()
 
-      if not file_io.file_exists(model_path):
+      if not file_io.file_exists_v2(model_path):
         raise RuntimeError()
-      if file_io.file_exists(write_model_path) != _is_chief(task_type, task_id):
+      if file_io.file_exists_v2(write_model_path) != _is_chief(
+          task_type, task_id):
         raise RuntimeError()
 
       loaded_model = keras.saving.save.load_model(model_path)
@@ -179,9 +180,9 @@
       # Make sure chief finishes saving before non-chief's assertions.
       multi_process_runner.barrier().wait()
 
-      if not file_io.file_exists(checkpoint_dir):
+      if not file_io.file_exists_v2(checkpoint_dir):
         raise RuntimeError()
-      if file_io.file_exists(write_checkpoint_dir) != _is_chief(
+      if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
           task_type, task_id):
         raise RuntimeError()
 
diff --git a/tensorflow/python/keras/distribute/parameter_server_training_test.py b/tensorflow/python/keras/distribute/parameter_server_training_test.py
index 70547ff..12a7db4 100644
--- a/tensorflow/python/keras/distribute/parameter_server_training_test.py
+++ b/tensorflow/python/keras/distribute/parameter_server_training_test.py
@@ -69,7 +69,7 @@
     ]
     label_vocab = ["yes", "no"]
 
-    with self.client.context():
+    with self.client.strategy.scope():
 
       # Define KPLs under client's context. Right now, if they have lookup
       # tables, they will be created on the client. Their variables will be
@@ -167,7 +167,7 @@
     for _ in range(10):
       self.client.schedule(worker_fn, args=(distributed_iterator,))
     self.client.join()
-    self.assertGreaterEqual(accuracy.result().numpy(), 0.5)
+    self.assertGreater(accuracy.result().numpy(), 0.0)
 
     # Create a saved model.
     model.feature_ps = feature_ps
diff --git a/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py b/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py
index d303a42..7815d74 100644
--- a/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py
+++ b/tensorflow/python/keras/distribute/saved_model_mixed_api_test.py
@@ -26,12 +26,15 @@
 
 from tensorflow.python.distribute import combinations
 from tensorflow.python.eager import test
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import saved_model_test_base as test_base
 from tensorflow.python.keras.saving import save
 
 _DEFAULT_FUNCTION_KEY = 'serving_default'
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
 
   def setUp(self):
diff --git a/tensorflow/python/keras/distribute/saved_model_save_load_test.py b/tensorflow/python/keras/distribute/saved_model_save_load_test.py
index 39856af..2174d39 100644
--- a/tensorflow/python/keras/distribute/saved_model_save_load_test.py
+++ b/tensorflow/python/keras/distribute/saved_model_save_load_test.py
@@ -24,6 +24,7 @@
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import test
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import model_combinations
 from tensorflow.python.keras.distribute import saved_model_test_base as test_base
 from tensorflow.python.ops import array_ops
@@ -32,6 +33,8 @@
 from tensorflow.python.saved_model import saved_model
 
 
+@testing_utils.run_all_without_tensor_float_32(
+    'Uses Dense layers, which call matmul')
 class SavedModelKerasModelTest(test_base.TestSavedModelBase):
 
   def setUp(self):
diff --git a/tensorflow/python/keras/distribute/worker_training_state.py b/tensorflow/python/keras/distribute/worker_training_state.py
index 29939ed..6385594 100644
--- a/tensorflow/python/keras/distribute/worker_training_state.py
+++ b/tensorflow/python/keras/distribute/worker_training_state.py
@@ -112,12 +112,12 @@
     successfully finishes.
     """
     # pylint: disable=protected-access
-    for pathname in file_io.get_matching_files(
+    for pathname in file_io.get_matching_files_v2(
         self.write_checkpoint_manager._prefix + '*'):
-      file_io.delete_recursively(pathname)
-    for pathname in file_io.get_matching_files(
+      file_io.delete_recursively_v2(pathname)
+    for pathname in file_io.get_matching_files_v2(
         os.path.join(self.write_checkpoint_manager.directory, 'checkpoint')):
-      file_io.delete_recursively(pathname)
+      file_io.delete_recursively_v2(pathname)
 
   def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
     """Maybe load initial epoch from ckpt considering possible worker recovery.
diff --git a/tensorflow/python/keras/distribute/worker_training_state_test.py b/tensorflow/python/keras/distribute/worker_training_state_test.py
index 80a3dea..0411aed 100644
--- a/tensorflow/python/keras/distribute/worker_training_state_test.py
+++ b/tensorflow/python/keras/distribute/worker_training_state_test.py
@@ -48,7 +48,7 @@
         callbacks.ModelCheckpoint(
             filepath=saving_filepath, save_weights_only=save_weights_only)
     ]
-    self.assertFalse(file_io.file_exists(saving_filepath))
+    self.assertFalse(file_io.file_exists_v2(saving_filepath))
 
     try:
       model.fit(
@@ -56,9 +56,9 @@
     except NotFoundError as e:
       if 'Failed to create a NewWriteableFile' in e.message:
         self.skipTest('b/138941852, path not found error in Windows py35.')
-    tf_saved_model_exists = file_io.file_exists(saving_filepath)
-    tf_weights_only_checkpoint_exists = file_io.file_exists(saving_filepath +
-                                                            '.index')
+    tf_saved_model_exists = file_io.file_exists_v2(saving_filepath)
+    tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
+        saving_filepath + '.index')
     self.assertTrue(tf_saved_model_exists or tf_weights_only_checkpoint_exists)
 
 
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 5ac0a6d..dbe83be 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -1006,10 +1006,10 @@
         np_arrays.ndarray, np.ndarray, float, int)) for x in input_list):
 
       def _convert_non_tensor(x):
-        # Don't call `ops.convert_to_tensor_v2` on all `inputs` because
+        # Don't call `ops.convert_to_tensor` on all `inputs` because
         # `SparseTensors` can't be converted to `Tensor`.
         if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
-          return ops.convert_to_tensor_v2(x)
+          return ops.convert_to_tensor_v2_with_dispatch(x)
         return x
 
       inputs = nest.map_structure(_convert_non_tensor, inputs)
@@ -1518,7 +1518,8 @@
       if loss is None:
         return None  # Will be filtered out when computing the .losses property
       if not tensor_util.is_tensor(loss):
-        loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx())
+        loss = ops.convert_to_tensor_v2_with_dispatch(
+            loss, dtype=backend.floatx())
       loss._unconditional_loss = True  # pylint: disable=protected-access
       return loss
 
@@ -1535,7 +1536,8 @@
         continue
       if not tensor_util.is_tensor(loss) and not isinstance(
           loss, keras_tensor.KerasTensor):
-        loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx())
+        loss = ops.convert_to_tensor_v2_with_dispatch(
+            loss, dtype=backend.floatx())
       # TF Functions should take the eager path.
       if ((tf_utils.is_symbolic_tensor(loss) or
            isinstance(loss, keras_tensor.KerasTensor)) and
@@ -1757,7 +1759,7 @@
     if not call_context.frozen:
       for update in nest.flatten(updates):
         if callable(update):
-          update()
+          update()  # pylint: disable=not-callable
 
   def set_weights(self, weights):
     """Sets the weights of the layer, from Numpy arrays.
@@ -2586,10 +2588,10 @@
     # we copy them to avoid loss of KerasHistory metadata.
     flat_outputs = nest.flatten(outputs)
     flat_inputs = nest.flatten((args, kwargs))
-    inputs_set = object_identity.ObjectIdentitySet(flat_inputs)
+    input_ids_set = {id(i) for i in flat_inputs}
     outputs_copy = []
     for x in flat_outputs:
-      if x in inputs_set:
+      if id(x) in input_ids_set:
         with backend.name_scope(self.name):
           x = array_ops.identity(x)
       outputs_copy.append(x)
@@ -2985,12 +2987,13 @@
 
   def _dedup_weights(self, weights):
     """Dedupe weights while maintaining order as much as possible."""
-    output, seen_weights = [], object_identity.ObjectIdentitySet()
+    output, seen_ids = [], set()
     for w in weights:
-      if w not in seen_weights:
+      if id(w) not in seen_ids:
         output.append(w)
         # Track the Variable's identity to avoid __eq__ issues.
-        seen_weights.add(w)
+        seen_ids.add(id(w))
+
     return output
 
   def _split_out_first_arg(self, args, kwargs):
@@ -3266,7 +3269,7 @@
 
 def _convert_numpy_or_python_types(x):
   if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)):
-    return ops.convert_to_tensor_v2(x)
+    return ops.convert_to_tensor_v2_with_dispatch(x)
   return x
 
 
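
`_dedup_weights` now tracks `id(w)` in a plain set instead of an `ObjectIdentitySet`; both dedupe by object identity, which matters because comparing variables with `==` builds an elementwise-equality op rather than returning a Python bool. A self-contained sketch of the same pattern on ordinary objects:

    def dedup(items):
      output, seen_ids = [], set()
      for item in items:
        if id(item) not in seen_ids:
          output.append(item)
          seen_ids.add(id(item))
      return output

    a, b = [0], [0]                  # equal by value, distinct by identity
    assert dedup([a, b, a]) == [a, b]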
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index 022718e..c377e3e 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -1135,7 +1135,7 @@
     self.assertEqual(sublayer.active_name_scope, 'MyName2/Sublayer')
 
   def test_name_scope_tf_tensor(self):
-    x = ops.convert_to_tensor_v2(np.ones((10, 10)))
+    x = ops.convert_to_tensor_v2_with_dispatch(np.ones((10, 10)))
     layer = layers.Dense(
         10, activation=layers.ReLU(name='MyAct'), name='MyName3')
     layer(x)
diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py
index 536efb5..a614152 100644
--- a/tensorflow/python/keras/engine/base_layer_v1.py
+++ b/tensorflow/python/keras/engine/base_layer_v1.py
@@ -690,10 +690,10 @@
     # Accept NumPy and scalar inputs by converting to Tensors.
     if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
       def _convert_non_tensor(x):
-        # Don't call `ops.convert_to_tensor_v2` on all `inputs` because
+        # Don't call `ops.convert_to_tensor` on all `inputs` because
         # `SparseTensors` can't be converted to `Tensor`.
         if isinstance(x, (np.ndarray, float, int)):
-          return ops.convert_to_tensor_v2(x)
+          return ops.convert_to_tensor_v2_with_dispatch(x)
         return x
       inputs = nest.map_structure(_convert_non_tensor, inputs)
       input_list = nest.flatten(inputs)
@@ -1053,7 +1053,8 @@
       if loss is None:
         return None  # Will be filtered out when computing the .losses property
       if not tensor_util.is_tensor(loss):
-        loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx())
+        loss = ops.convert_to_tensor_v2_with_dispatch(
+            loss, dtype=backend.floatx())
       loss._unconditional_loss = (inputs is None)  # pylint: disable=protected-access
       return loss
 
@@ -1068,7 +1069,8 @@
       if loss is None:
         continue
       if not tensor_util.is_tensor(loss):
-        loss = ops.convert_to_tensor_v2(loss, dtype=backend.floatx())
+        loss = ops.convert_to_tensor_v2_with_dispatch(
+            loss, dtype=backend.floatx())
       # TF Functions should take the eager path.
       if (tf_utils.is_symbolic_tensor(loss) and
           not base_layer_utils.is_in_tf_function()):
@@ -1229,7 +1231,7 @@
       elif hasattr(x, 'op'):
         update = x.op
       else:
-        update = ops.convert_to_tensor_v2(x)
+        update = ops.convert_to_tensor_v2_with_dispatch(x)
 
       reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, [update])
       update._unconditional_update = update not in reachable
@@ -2378,12 +2380,13 @@
 
   def _dedup_weights(self, weights):
     """Dedupe weights while maintaining order as much as possible."""
-    output, seen_weights = [], object_identity.ObjectIdentitySet()
+    output, seen_ids = [], set()
     for w in weights:
-      if w not in seen_weights:
+      if id(w) not in seen_ids:
         output.append(w)
         # Track the Variable's identity to avoid __eq__ issues.
-        seen_weights.add(w)
+        seen_ids.add(id(w))
+
     return output
 
   # SavedModel properties. Please see keras/saving/saved_model for details.
diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py
index f5577bf..a02a329 100644
--- a/tensorflow/python/keras/engine/base_preprocessing_layer.py
+++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py
@@ -149,7 +149,7 @@
     else:
       accumulator = self._combiner.restore(self._restore_updates())
     if isinstance(data, (list, tuple)):
-      data = ops.convert_to_tensor_v2(data)
+      data = ops.convert_to_tensor_v2_with_dispatch(data)
     if not isinstance(data,
                       (dataset_ops.DatasetV2,
                        np.ndarray,
diff --git a/tensorflow/python/keras/engine/compile_utils_test.py b/tensorflow/python/keras/engine/compile_utils_test.py
index 3912727..ae92b9a 100644
--- a/tensorflow/python/keras/engine/compile_utils_test.py
+++ b/tensorflow/python/keras/engine/compile_utils_test.py
@@ -53,7 +53,7 @@
 
     y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
     y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     total_loss = loss_container(y_t, y_p, sample_weight=sw)
 
@@ -86,7 +86,7 @@
 
     y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
     y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     total_loss = loss_container(y_t, y_p, sample_weight=sw)
 
@@ -112,7 +112,7 @@
 
     y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
     y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     total_loss = loss_container(y_t, y_p, sample_weight=sw)
 
@@ -135,7 +135,7 @@
 
     y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
     y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     total_loss = loss_container(y_t, y_p, sample_weight=sw)
 
@@ -170,7 +170,7 @@
               array_ops.zeros((10, 1))],
         'a': array_ops.ones((10, 1))
     }
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     total_loss = loss_container(y_t, y_p, sample_weight=sw)
     self.assertEqual(total_loss.numpy(), 0.75)
@@ -193,7 +193,7 @@
 
     y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
     y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     total_loss = loss_container(y_t, y_p, sample_weight=sw)
     self.assertEqual(total_loss.numpy(), 0.5)
@@ -220,13 +220,13 @@
     })
 
     y_p = {
-        'output1': ops.convert_to_tensor([[0], [1], [2]]),
-        'output2': ops.convert_to_tensor([[3], [4], [5]]),
-        'output3': ops.convert_to_tensor([[6], [7], [8]])
+        'output1': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]),
+        'output2': ops.convert_to_tensor_v2_with_dispatch([[3], [4], [5]]),
+        'output3': ops.convert_to_tensor_v2_with_dispatch([[6], [7], [8]])
     }
     y_t = {
-        'output1': ops.convert_to_tensor([[1], [2], [3]]),
-        'output3': ops.convert_to_tensor([[4], [5], [6]])
+        'output1': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]]),
+        'output3': ops.convert_to_tensor_v2_with_dispatch([[4], [5], [6]])
     }
 
     total_loss = loss_container(y_t, y_p)
@@ -372,7 +372,7 @@
 
     y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
     y_p = [array_ops.ones((10, 1)), 2 * array_ops.ones((10, 1))]
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
     metric_container.update_state(y_t, y_p, sample_weight=sw)
     self.assertLen(metric_container.metrics, 6)
 
@@ -415,7 +415,7 @@
 
     y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
     y_p = {'out1': array_ops.ones((10, 1)), 'out2': 2 * array_ops.ones((10, 1))}
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
     metric_container.update_state(y_t, y_p, sample_weight=sw)
 
     mse_metric = metric_container.metrics[0]
@@ -440,7 +440,7 @@
 
     y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
     y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     metric_container.update_state(y_t, y_p, sample_weight=sw)
     self.assertLen(metric_container.metrics, 1)
@@ -457,7 +457,7 @@
 
     y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
     y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     metric_container.update_state(y_t, y_p, sample_weight=sw)
     self.assertLen(metric_container.metrics, 1)
@@ -487,7 +487,7 @@
               array_ops.zeros((10, 1))],
         'a': array_ops.ones((10, 1))
     }
-    sw = ops.convert_to_tensor_v2([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
 
     metric_container.update_state(y_t, y_p, sample_weight=sw)
     self.assertLen(metric_container.metrics, 3)
@@ -548,9 +548,9 @@
     metric_container = compile_utils.MetricsContainer(
         metrics=['mae'], weighted_metrics=['mae'])
 
-    y_t = ops.convert_to_tensor_v2([[0], [3], [0]])
-    y_p = ops.convert_to_tensor_v2([[0], [0], [0]])
-    sw = ops.convert_to_tensor_v2([[1], [0], [1]])
+    y_t = ops.convert_to_tensor_v2_with_dispatch([[0], [3], [0]])
+    y_p = ops.convert_to_tensor_v2_with_dispatch([[0], [0], [0]])
+    sw = ops.convert_to_tensor_v2_with_dispatch([[1], [0], [1]])
 
     metric_container.update_state(y_t, y_p, sample_weight=sw)
     self.assertLen(metric_container.metrics, 2)
@@ -566,8 +566,8 @@
   def test_broadcast_metrics_to_dict(self):
     metric_container = compile_utils.MetricsContainer(metrics=['mae'])
 
-    y_p = {'output': ops.convert_to_tensor([[0], [1], [2]])}
-    y_t = {'output': ops.convert_to_tensor([[1], [2], [3]])}
+    y_p = {'output': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]])}
+    y_t = {'output': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]])}
     metric_container.update_state(y_t, y_p)
 
     mae_metric = metric_container.metrics[0]
@@ -578,8 +578,8 @@
     metric_container = compile_utils.MetricsContainer(
         metrics=['mae'], output_names=['output'])
 
-    y_p = ops.convert_to_tensor([[0], [1], [2]])
-    y_t = {'output': ops.convert_to_tensor([[1], [2], [3]])}
+    y_p = ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]])
+    y_t = {'output': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]])}
     metric_container.update_state(y_t, y_p)
 
     mae_metric = metric_container.metrics[0]
@@ -595,13 +595,13 @@
     })
 
     y_p = {
-        'output1': ops.convert_to_tensor([[0], [1], [2]]),
-        'output2': ops.convert_to_tensor([[3], [4], [5]]),
-        'output3': ops.convert_to_tensor([[6], [7], [8]])
+        'output1': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]),
+        'output2': ops.convert_to_tensor_v2_with_dispatch([[3], [4], [5]]),
+        'output3': ops.convert_to_tensor_v2_with_dispatch([[6], [7], [8]])
     }
     y_t = {
-        'output1': ops.convert_to_tensor([[1], [2], [3]]),
-        'output3': ops.convert_to_tensor([[4], [5], [6]])
+        'output1': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]]),
+        'output3': ops.convert_to_tensor_v2_with_dispatch([[4], [5], [6]])
     }
 
     metric_container.update_state(y_t, y_p)
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index e9662da..0df15f3 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -1006,7 +1006,7 @@
       dtype = None
       if issubclass(x.dtype.type, np.floating):
         dtype = backend.floatx()
-      return ops.convert_to_tensor(x, dtype=dtype)
+      return ops.convert_to_tensor_v2_with_dispatch(x, dtype=dtype)
     elif scipy_sparse and scipy_sparse.issparse(x):
       return _scipy_sparse_to_sparse_tensor(x)
     return x
@@ -1281,7 +1281,7 @@
         "than the number of classes, found {}").format(class_weight)
     raise ValueError(error_msg)
 
-  class_weight_tensor = ops.convert_to_tensor_v2(
+  class_weight_tensor = ops.convert_to_tensor_v2_with_dispatch(
       [class_weight[int(c)] for c in class_ids])
 
   def _class_weights_map_fn(*data):
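
The class-weight lookup tensor built above turns a `{class_id: weight}` dict into a dense tensor indexed by class id. A hypothetical illustration (values invented):

    from tensorflow.python.framework import ops

    class_weight = {0: 1.0, 1: 3.0, 2: 0.5}
    class_ids = sorted(class_weight.keys())             # [0, 1, 2]
    class_weight_tensor = ops.convert_to_tensor_v2_with_dispatch(
        [class_weight[int(c)] for c in class_ids])      # [1.0, 3.0, 0.5]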
diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py
index fad1930..b17410c 100644
--- a/tensorflow/python/keras/engine/data_adapter_test.py
+++ b/tensorflow/python/keras/engine/data_adapter_test.py
@@ -446,7 +446,7 @@
   def test_training(self):
     # First verify that DummyArrayLike can't be converted to a Tensor
     with self.assertRaises(TypeError):
-      ops.convert_to_tensor_v2(self.arraylike_input)
+      ops.convert_to_tensor_v2_with_dispatch(self.arraylike_input)
 
     # Then train on the array like.
     # It should not be converted to a tensor directly (which would force it into
@@ -914,7 +914,7 @@
     def generator():
       for _ in range(2):
         for step in range(3):
-          yield (ops.convert_to_tensor_v2([step]),)
+          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)
 
     data_handler = data_adapter.DataHandler(
         generator(), epochs=2, steps_per_epoch=3)
@@ -1007,20 +1007,20 @@
       y = np.array([0, 2, 4, 6, 8])
       sw = np.array([0, 4, 8, 12, 16])
     else:
-      x = ops.convert_to_tensor_v2([0, 1, 2, 3, 4])
-      y = ops.convert_to_tensor_v2([0, 2, 4, 6, 8])
-      sw = ops.convert_to_tensor_v2([0, 4, 8, 12, 16])
+      x = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2, 3, 4])
+      y = ops.convert_to_tensor_v2_with_dispatch([0, 2, 4, 6, 8])
+      sw = ops.convert_to_tensor_v2_with_dispatch([0, 4, 8, 12, 16])
 
     (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
         data_adapter.train_validation_split((x, y, sw), validation_split=0.2))
 
     if use_numpy:
-      train_x = ops.convert_to_tensor_v2(train_x)
-      train_y = ops.convert_to_tensor_v2(train_y)
-      train_sw = ops.convert_to_tensor_v2(train_sw)
-      val_x = ops.convert_to_tensor_v2(val_x)
-      val_y = ops.convert_to_tensor_v2(val_y)
-      val_sw = ops.convert_to_tensor_v2(val_sw)
+      train_x = ops.convert_to_tensor_v2_with_dispatch(train_x)
+      train_y = ops.convert_to_tensor_v2_with_dispatch(train_y)
+      train_sw = ops.convert_to_tensor_v2_with_dispatch(train_sw)
+      val_x = ops.convert_to_tensor_v2_with_dispatch(val_x)
+      val_y = ops.convert_to_tensor_v2_with_dispatch(val_y)
+      val_sw = ops.convert_to_tensor_v2_with_dispatch(val_sw)
 
     self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3])
     self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6])
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 9cb35ff..52d73ad 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -445,6 +445,7 @@
               loss_weights=None,
               weighted_metrics=None,
               run_eagerly=None,
+              steps_per_execution=None,
               **kwargs):
     """Configures the model for training.
 
@@ -496,17 +497,18 @@
           logic will not be wrapped in a `tf.function`. Recommended to leave
           this as `None` unless your `Model` cannot be run inside a
           `tf.function`.
-        **kwargs: Any additional arguments. Supported arguments:
-            - `experimental_steps_per_execution`: Int. The number of batches to
-              run during each `tf.function` call. Running multiple batches
-              inside a single `tf.function` call can greatly improve performance
-              on TPUs or small models with a large Python overhead. Note that if
-              this value is set to `N`, `Callback.on_batch` methods will only be
-              called every `N` batches. This currently defaults to `1`. At most,
-              one full epoch will be run each execution. If a number larger than
-              the size of the epoch is passed, the execution will be truncated
-              to the size of the epoch.
-            - `sample_weight_mode` for backward compatibility.
+        steps_per_execution: Int. Defaults to 1. The number of batches to
+          run during each `tf.function` call. Running multiple batches
+          inside a single `tf.function` call can greatly improve performance
+          on TPUs or small models with a large Python overhead. At most, one
+          full epoch will be run during each execution. If a number larger
+          than the size of the epoch is passed, the execution will be
+          truncated to the size of the epoch. Note that if
+          `steps_per_execution` is set to `N`, the `Callback.on_batch_begin`
+          and `Callback.on_batch_end` methods will only be called every `N`
+          batches (i.e. before/after each `tf.function` execution).
+        **kwargs: Arguments supported for backwards compatibility only.
 
     Raises:
         ValueError: In case of invalid arguments for
@@ -514,6 +516,13 @@
     """
     base_layer.keras_api_gauge.get_cell('compile').set(True)
     with self.distribute_strategy.scope():
+      if 'experimental_steps_per_execution' in kwargs:
+        logging.warn('The argument `steps_per_execution` is no longer '
+                     'experimental. Pass `steps_per_execution` instead of '
+                     '`experimental_steps_per_execution`.')
+        if not steps_per_execution:
+          steps_per_execution = kwargs.pop('experimental_steps_per_execution')
+
       self._validate_compile(optimizer, metrics, **kwargs)
       self._run_eagerly = run_eagerly
 
@@ -523,9 +532,7 @@
       self.compiled_metrics = compile_utils.MetricsContainer(
           metrics, weighted_metrics, output_names=self.output_names)
 
-      experimental_steps_per_execution = kwargs.pop(
-          'experimental_steps_per_execution', 1)
-      self._configure_steps_per_execution(experimental_steps_per_execution)
+      self._configure_steps_per_execution(steps_per_execution or 1)
 
       # Initializes attrs that are reset each time `compile` is called.
       self._reset_compile_cache()
@@ -2460,9 +2467,7 @@
     if kwargs.pop('target_tensors', None) is not None:
       raise ValueError(
           'target_tensors argument is not supported when executing eagerly.')
-    invalid_kwargs = set(kwargs) - {
-        'experimental_steps_per_execution', 'sample_weight_mode'
-    }
+    invalid_kwargs = set(kwargs) - {'sample_weight_mode'}
     if invalid_kwargs:
       raise TypeError('Invalid keyword argument(s) in `compile`: %s' %
                       (invalid_kwargs,))
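
A usage sketch of the now-stable argument, mirroring the updated tests in distribute_strategy_test.py above (model and data are illustrative):

    import numpy as np
    from tensorflow.python import keras

    inputs = keras.layers.Input(shape=(10,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)

    # Run 20 batches per tf.function call; batch-level callbacks then fire
    # once every 20 batches.
    model.compile('sgd', 'mse', steps_per_execution=20)
    x, y = np.ones((100, 10)), np.ones((100, 1))
    model.fit(x, y, batch_size=2, epochs=1)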
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index b3ce3d1..09e6f0d 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -121,7 +121,7 @@
   if any(
       isinstance(input_t, (np.ndarray, float, int))
       for input_t in nest.flatten(inputs)):
-    inputs = nest.map_structure(ops.convert_to_tensor_v2, inputs)
+    inputs = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch, inputs)
 
   outs = model(inputs, **kwargs)
   outs = nest.flatten(outs)
@@ -131,7 +131,8 @@
   # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
   if sample_weights:
     sample_weights = [
-        training_utils.cast_if_floating_dtype(ops.convert_to_tensor_v2(val))
+        training_utils.cast_if_floating_dtype(
+            ops.convert_to_tensor_v2_with_dispatch(val))
         if val is not None else None for val in sample_weights
     ]
 
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 84bcd99..7a8c1c1 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -1009,7 +1009,7 @@
       class_sample_weight = math_ops.cast(class_sample_weight, K.floatx())
       if sample_weight is not None:
         sample_weight = math_ops.cast(
-            ops.convert_to_tensor_v2(sample_weight), K.floatx())
+            ops.convert_to_tensor_v2_with_dispatch(sample_weight), K.floatx())
     else:
       y_classes = y
       if len(y.shape) == 2:
@@ -1365,7 +1365,7 @@
 
 def cast_single_tensor(x, dtype=None):
   if isinstance(x, np.ndarray):
-    x = ops.convert_to_tensor_v2(x)
+    x = ops.convert_to_tensor_v2_with_dispatch(x)
   dtype = dtype or K.floatx()
   if x.dtype.is_floating:
     return math_ops.cast(x, dtype=dtype)
@@ -1391,7 +1391,7 @@
   new_targets = []
   for target, out in zip(targets, outputs):
     if isinstance(target, np.ndarray):
-      target = ops.convert_to_tensor_v2(target)
+      target = ops.convert_to_tensor_v2_with_dispatch(target)
     if target.dtype != out.dtype:
       new_targets.append(cast_single_tensor(target, dtype=out.dtype))
     else:
diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD
index 6458d09..b5b876b 100644
--- a/tensorflow/python/keras/layers/BUILD
+++ b/tensorflow/python/keras/layers/BUILD
@@ -814,7 +814,6 @@
     python_version = "PY3",
     shard_count = 12,
     tags = [
-        "no_cuda11",
         "no_oss",
     ],
     xla_enable_strict_auto_jit = False,
diff --git a/tensorflow/python/keras/layers/convolutional_transpose_test.py b/tensorflow/python/keras/layers/convolutional_transpose_test.py
index dd73d22..4326044 100644
--- a/tensorflow/python/keras/layers/convolutional_transpose_test.py
+++ b/tensorflow/python/keras/layers/convolutional_transpose_test.py
@@ -207,3 +207,6 @@
             },
             input_shape=(None, 3, None, None, None),
             input_data=input_data)
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 36ac087..1ceedad 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -201,7 +201,7 @@
     noise_shape = []
     for i, value in enumerate(self.noise_shape):
       noise_shape.append(concrete_inputs_shape[i] if value is None else value)
-    return ops.convert_to_tensor_v2(noise_shape)
+    return ops.convert_to_tensor_v2_with_dispatch(noise_shape)
 
   def call(self, inputs, training=None):
     if training is None:
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index f650981..b7a11d3 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -504,14 +504,14 @@
         keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2))
 
   def test_dense_dtype(self):
-    inputs = ops.convert_to_tensor_v2(
+    inputs = ops.convert_to_tensor_v2_with_dispatch(
         np.random.randint(low=0, high=7, size=(2, 2)))
     layer = keras.layers.Dense(5, dtype='float32')
     outputs = layer(inputs)
     self.assertEqual(outputs.dtype, 'float32')
 
   def test_dense_with_policy(self):
-    inputs = ops.convert_to_tensor_v2(
+    inputs = ops.convert_to_tensor_v2_with_dispatch(
         np.random.randint(low=0, high=7, size=(2, 2)))
     layer = keras.layers.Dense(5, dtype=policy.Policy('mixed_float16'))
     outputs = layer(inputs)
diff --git a/tensorflow/python/keras/layers/dense_attention.py b/tensorflow/python/keras/layers/dense_attention.py
index cd277a1..ab29125 100644
--- a/tensorflow/python/keras/layers/dense_attention.py
+++ b/tensorflow/python/keras/layers/dense_attention.py
@@ -180,7 +180,7 @@
       q_mask = mask[0]
       if q_mask is None:
         return None
-      return ops.convert_to_tensor_v2(q_mask)
+      return ops.convert_to_tensor_v2_with_dispatch(q_mask)
     return None
 
   def _validate_call_args(self, inputs, mask):
diff --git a/tensorflow/python/keras/layers/kernelized.py b/tensorflow/python/keras/layers/kernelized.py
index eac985e..c8a6a65 100644
--- a/tensorflow/python/keras/layers/kernelized.py
+++ b/tensorflow/python/keras/layers/kernelized.py
@@ -218,7 +218,7 @@
     super(RandomFourierFeatures, self).build(input_shape)
 
   def call(self, inputs):
-    inputs = ops.convert_to_tensor_v2(inputs, dtype=self.dtype)
+    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs, dtype=self.dtype)
     inputs = gen_math_ops.cast(inputs, dtypes.float32)
     kernel = (1.0 / self.kernel_scale) * self.unscaled_kernel
     outputs = gen_math_ops.mat_mul(inputs, kernel)
diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py
index 1e33edd..416a26d 100644
--- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py
+++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py
@@ -282,7 +282,7 @@
   def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
     if inputs is not None:
       # Validate the given batch_size and dtype against inputs if provided.
-      inputs = ops.convert_to_tensor(inputs, name="inputs")
+      inputs = ops.convert_to_tensor_v2_with_dispatch(inputs, name="inputs")
       if batch_size is not None:
         if tensor_util.is_tensor(batch_size):
           static_batch_size = tensor_util.constant_value(
diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py
index 2e39239..9618bc7 100644
--- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py
+++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py
@@ -116,7 +116,7 @@
     with ops.name_scope_v2("DropoutWrapperInit"):
 
       def tensor_and_const_value(v):
-        tensor_value = ops.convert_to_tensor(v)
+        tensor_value = ops.convert_to_tensor_v2_with_dispatch(v)
         const_value = tensor_util.constant_value(tensor_value)
         return (tensor_value, const_value)
 
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index d9bac2c..92178bc 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -41,26 +41,43 @@
 
 
 class BatchNormalizationBase(Layer):
-  r"""Normalize and scale inputs or activations.
+  r"""Layer that normalizes its inputs.
 
-  Normalize the activations of the previous layer at each batch,
-  i.e. applies a transformation that maintains the mean activation
-  close to 0 and the activation standard deviation close to 1.
+  Batch normalization applies a transformation that maintains the mean output
+  close to 0 and the output standard deviation close to 1.
 
-  Batch normalization differs from other layers in several key aspects:
+  Importantly, batch normalization works differently during training and
+  during inference.
 
-  1) Adding BatchNormalization with `training=True` to a model causes the
-  result of one example to depend on the contents of all other examples in a
-  minibatch. Be careful when padding batches or masking examples, as these can
-  change the minibatch statistics and affect other examples.
+  **During training** (i.e. when using `fit()` or when calling the layer/model
+  with the argument `training=True`), the layer normalizes its output using
+  the mean and standard deviation of the current batch of inputs. That is to
+  say, for each channel being normalized, the layer returns
+  `(batch - mean(batch)) / sqrt(var(batch) + epsilon) * gamma + beta`, where:
 
-  2) Updates to the weights (moving statistics) are based on the forward pass
-  of a model rather than the result of gradient computations.
+  - `epsilon` is a small constant (configurable as part of the constructor
+  arguments).
+  - `gamma` is a learned scaling factor (initialized as 1), which
+  can be disabled by passing `scale=False` to the constructor.
+  - `beta` is a learned offset factor (initialized as 0), which
+  can be disabled by passing `center=False` to the constructor.
 
-  3) When performing inference using a model containing batch normalization, it
-  is generally (though not always) desirable to use accumulated statistics
-  rather than mini-batch statistics. This is accomplished by passing
-  `training=False` when calling the model, or using `model.predict`.
+  **During inference** (i.e. when using `evaluate()` or `predict()`, or when
+  calling the layer/model with the argument `training=False`, which is the
+  default), the layer normalizes its output using a moving average of the
+  mean and standard deviation of the batches it has seen during training. That
+  is to say, it returns
+  `(batch - self.moving_mean) / sqrt(self.moving_var + epsilon) * gamma + beta`.
+
+  `self.moving_mean` and `self.moving_var` are non-trainable variables that
+  are updated each time the layer is called in training mode, as follows:
+
+  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
+  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`
+
+  As such, the layer will only normalize its inputs during inference
+  *after having been trained on data that has statistics similar to those of
+  the inference data*.
 
   Arguments:
     axis: Integer, the axis that should be normalized (typically the features
@@ -117,6 +134,7 @@
             across all examples), and finally apply gamma and/or beta. If
             `None`, no adjustment is applied. Cannot be specified if
             virtual_batch_size is specified.
+
   Call arguments:
     inputs: Input tensor (of any rank).
     training: Python boolean indicating whether the layer should behave in
@@ -125,21 +143,13 @@
         variance of the current batch of inputs.
       - `training=False`: The layer will normalize its inputs using the mean and
         variance of its moving statistics, learned during training.
+
   Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
     integers, does not include the samples axis) when using this layer as the
     first layer in a model.
+
   Output shape: Same shape as input.  {{TRAINABLE_ATTRIBUTE_NOTE}}
-  Normalization equations: Consider the intermediate activations \(x\) of a
-    mini-batch of size
-    \\(m\\):  We can compute the mean and variance of the batch  \\({\mu_B} =
-      \frac{1}{m} \sum_{i=1}^{m} {x_i}\\)  \\({\sigma_B^2} = \frac{1}{m}
-      \sum_{i=1}^{m} ({x_i} - {\mu_B})^2\\)  and then compute a normalized
-      \\(x\\), including a small factor \\({\epsilon}\\) for numerical
-      stability.  \\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 +
-      \epsilon}}\\)  And finally \\(\hat{x}\) is linearly transformed by
-      \({\gamma}\\)
-    and \\({\beta}\\), which are learned parameters:  \\({y_i} = {\gamma *
-      \hat{x_i} + \beta}\\)
+
   Reference:
     - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
   """
@@ -480,7 +490,8 @@
   def _assign_moving_average(self, variable, value, momentum, inputs_size):
     with K.name_scope('AssignMovingAvg') as scope:
       with ops.colocate_with(variable):
-        decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay')
+        decay = ops.convert_to_tensor_v2_with_dispatch(
+            1.0 - momentum, name='decay')
         if decay.dtype != variable.dtype.base_dtype:
           decay = math_ops.cast(decay, variable.dtype.base_dtype)
         update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
@@ -585,7 +596,7 @@
                                                   lambda: self.momentum,
                                                   lambda: 1.0)
         else:
-          momentum = ops.convert_to_tensor_v2(self.momentum)
+          momentum = ops.convert_to_tensor_v2_with_dispatch(self.momentum)
 
       def mean_update():
         """Update self.moving_mean with the most recent data point."""
@@ -787,10 +798,11 @@
       moving_variance = self.moving_variance
 
       mean = control_flow_util.smart_cond(
-          training, lambda: mean, lambda: ops.convert_to_tensor_v2(moving_mean))
+          training, lambda: mean,
+          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_mean))
       variance = control_flow_util.smart_cond(
           training, lambda: variance,
-          lambda: ops.convert_to_tensor_v2(moving_variance))
+          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_variance))
 
       if self.virtual_batch_size is not None:
         # This isn't strictly correct since in ghost batch norm, you are
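
A standalone NumPy sketch of the training-mode equation in the rewritten docstring (using the default `gamma`, `beta`, and `epsilon` values; not part of the layer code):

    import numpy as np

    batch = np.random.randn(8, 4).astype(np.float32)   # 8 examples, 4 channels
    epsilon, gamma, beta = 1e-3, 1.0, 0.0
    mean = batch.mean(axis=0)
    var = batch.var(axis=0)
    normalized = (batch - mean) / np.sqrt(var + epsilon) * gamma + beta
    # Per-channel mean is ~0 and std is ~1 after normalization.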
diff --git a/tensorflow/python/keras/layers/preprocessing/category_crossing.py b/tensorflow/python/keras/layers/preprocessing/category_crossing.py
index bdb29d2..747a105 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_crossing.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_crossing.py
@@ -143,7 +143,7 @@
 
   def _preprocess_input(self, inp):
     if isinstance(inp, (list, tuple, np.ndarray)):
-      inp = ops.convert_to_tensor(inp)
+      inp = ops.convert_to_tensor_v2_with_dispatch(inp)
     if inp.shape.rank == 1:
       inp = array_ops.expand_dims(inp, axis=-1)
     return inp
diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding.py b/tensorflow/python/keras/layers/preprocessing/category_encoding.py
index 9554017..87112fa 100644
--- a/tensorflow/python/keras/layers/preprocessing/category_encoding.py
+++ b/tensorflow/python/keras/layers/preprocessing/category_encoding.py
@@ -269,7 +269,7 @@
 
   def call(self, inputs, count_weights=None):
     if isinstance(inputs, (list, np.ndarray)):
-      inputs = ops.convert_to_tensor_v2(inputs)
+      inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
     if inputs.shape.rank == 1:
       inputs = array_ops.expand_dims(inputs, 1)
 
diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py
index a6de075..ea8d6f0 100644
--- a/tensorflow/python/keras/layers/preprocessing/hashing.py
+++ b/tensorflow/python/keras/layers/preprocessing/hashing.py
@@ -154,7 +154,7 @@
 
   def _preprocess_single_input(self, inp):
     if isinstance(inp, (list, tuple, np.ndarray)):
-      inp = ops.convert_to_tensor(inp)
+      inp = ops.convert_to_tensor_v2_with_dispatch(inp)
     return inp
 
   def _preprocess_inputs(self, inputs):
diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
index 87a18db..6e0098c 100644
--- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
+++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py
@@ -681,7 +681,7 @@
         if output_shape_value is not None:
           output_shape = output_shape_value
 
-    output_shape = ops.convert_to_tensor_v2(
+    output_shape = ops.convert_to_tensor_v2_with_dispatch(
         output_shape, dtypes.int32, name='output_shape')
 
     if not output_shape.get_shape().is_compatible_with([2]):
diff --git a/tensorflow/python/keras/layers/preprocessing/normalization.py b/tensorflow/python/keras/layers/preprocessing/normalization.py
index 4b75def..b8cf233 100644
--- a/tensorflow/python/keras/layers/preprocessing/normalization.py
+++ b/tensorflow/python/keras/layers/preprocessing/normalization.py
@@ -145,7 +145,7 @@
     super(Normalization, self).build(input_shape)
 
   def call(self, inputs):
-    inputs = ops.convert_to_tensor_v2(inputs)
+    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
     if inputs.shape.rank == 1:
       inputs = array_ops.expand_dims(inputs, 1)
     # If the inputs are not floats, cast them to floats. This avoids issues
diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils.py b/tensorflow/python/keras/layers/preprocessing/table_utils.py
index 3329f32..c72b825 100644
--- a/tensorflow/python/keras/layers/preprocessing/table_utils.py
+++ b/tensorflow/python/keras/layers/preprocessing/table_utils.py
@@ -62,8 +62,10 @@
       raise RuntimeError("Size mismatch between values and key arrays. "
                          "Keys had size %s, values had size %s." %
                          (len(keys), len(values)))
-    keys = ops.convert_to_tensor(keys, dtype=self.table._key_dtype)  # pylint: disable=protected-access
-    values = ops.convert_to_tensor(values, dtype=self.table._value_dtype)  # pylint: disable=protected-access
+    keys = ops.convert_to_tensor_v2_with_dispatch(
+        keys, dtype=self.table._key_dtype)  # pylint: disable=protected-access
+    values = ops.convert_to_tensor_v2_with_dispatch(
+        values, dtype=self.table._value_dtype)  # pylint: disable=protected-access
     if values.shape.ndims != 1:
       raise ValueError("`values` must be 1-dimensional, got an input with "
                        " %s dimensions." % values.shape.ndims)
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
index 2cc8bc2..36e326b 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
@@ -367,7 +367,7 @@
     # on an implicit call to `build` in the base layer's `adapt`, since
     # preprocessing changes the input shape.
     if isinstance(data, (list, tuple, np.ndarray)):
-      data = ops.convert_to_tensor(data)
+      data = ops.convert_to_tensor_v2_with_dispatch(data)
 
     if isinstance(data, ops.Tensor):
       if data.shape.rank == 1:
@@ -566,7 +566,7 @@
 
   def call(self, inputs):
     if isinstance(inputs, (list, tuple, np.ndarray)):
-      inputs = ops.convert_to_tensor(inputs)
+      inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
 
     self._called = True
     inputs = self._preprocess(inputs)
diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py
index a2ed714..c7a9a87 100644
--- a/tensorflow/python/keras/layers/recurrent_v2.py
+++ b/tensorflow/python/keras/layers/recurrent_v2.py
@@ -387,9 +387,9 @@
       else:
         logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name)
 
-    # TODO(b/162616551): Remove all compat statements after 08/20/2020.
+    # TODO(b/162616551): Remove all compat statements after 09/02/2020.
     # This follows b/161915509 and is mainly to test the stateless Case op.
-    if compat.forward_compatible(2020, 8, 20):
+    if compat.forward_compatible(2020, 9, 2):
       # The first two attributes are added to support TFLite use case.
       supportive_attributes = {
           'time_major': time_major,
@@ -483,7 +483,7 @@
     if dropout_mask is not None:
       inputs = inputs * dropout_mask[0]
 
-    if compat.forward_compatible(2020, 8, 20):
+    if compat.forward_compatible(2020, 9, 2):
       gru_kwargs = {
           'inputs': inputs,
           'init_h': _read_variable_value(initial_state[0]),
@@ -797,7 +797,7 @@
         true_fn=cudnn_gru_fn,
         false_fn=standard_gru_fn)
 
-  if compat.forward_compatible(2020, 8, 20):
+  if compat.forward_compatible(2020, 9, 2):
     # Chooses the implementation dynamically based on the running device.
     (last_output, outputs, new_h,
      runtime) = control_flow_ops.execute_fn_for_device(
@@ -1141,7 +1141,7 @@
       else:
         logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name)
 
-    if compat.forward_compatible(2020, 8, 20):
+    if compat.forward_compatible(2020, 9, 2):
       # The first two attributes are added to support TFLite use case.
       supportive_attributes = {
           'time_major': time_major,
@@ -1202,7 +1202,7 @@
       dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
       if dropout_mask is not None:
         inputs = inputs * dropout_mask[0]
-      if compat.forward_compatible(2020, 8, 20):
+      if compat.forward_compatible(2020, 9, 2):
         lstm_kwargs = {
             'inputs':
                 inputs,
@@ -1633,7 +1633,7 @@
         true_fn=cudnn_lstm_fn,
         false_fn=stardard_lstm_fn)
 
-  if compat.forward_compatible(2020, 8, 20):
+  if compat.forward_compatible(2020, 9, 2):
     # Chooses the implementation dynamically based on the running device.
     (last_output, outputs, new_h, new_c,
      runtime) = control_flow_ops.execute_fn_for_device(
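The repeated date bump in the hunks above leans on TensorFlow's forward-compatibility window: `compat.forward_compatible(year, month, day)` returns True only once the running build is past the given date plus the compatibility horizon, so a new kernel can be staged in without breaking older consumers. A hedged sketch of the gating pattern, with placeholder branch bodies:

    from tensorflow.python.compat import compat

    def pick_rnn_impl():
      # Until 2020-09-02 every build takes the legacy path; afterwards the
      # implementation based on the stateless Case op is selected.
      if compat.forward_compatible(2020, 9, 2):
        return "new implementation"      # placeholder
      return "legacy implementation"     # placeholder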
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
index b0fd518..19ea3dc 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
@@ -40,8 +40,10 @@
 
   def testResidualWrapper(self):
     wrapper_type = rnn_cell_wrapper_v2.ResidualWrapper
-    x = ops.convert_to_tensor_v2(np.array([[1., 1., 1.]]), dtype="float32")
-    m = ops.convert_to_tensor_v2(np.array([[0.1, 0.1, 0.1]]), dtype="float32")
+    x = ops.convert_to_tensor_v2_with_dispatch(
+        np.array([[1., 1., 1.]]), dtype="float32")
+    m = ops.convert_to_tensor_v2_with_dispatch(
+        np.array([[0.1, 0.1, 0.1]]), dtype="float32")
     base_cell = rnn_cell_impl.GRUCell(
         3, kernel_initializer=init_ops.constant_initializer(0.5),
         bias_initializer=init_ops.constant_initializer(0.5))
@@ -62,9 +64,10 @@
 
   def testResidualWrapperWithSlice(self):
     wrapper_type = rnn_cell_wrapper_v2.ResidualWrapper
-    x = ops.convert_to_tensor_v2(
+    x = ops.convert_to_tensor_v2_with_dispatch(
         np.array([[1., 1., 1., 1., 1.]]), dtype="float32")
-    m = ops.convert_to_tensor_v2(np.array([[0.1, 0.1, 0.1]]), dtype="float32")
+    m = ops.convert_to_tensor_v2_with_dispatch(
+        np.array([[0.1, 0.1, 0.1]]), dtype="float32")
     base_cell = rnn_cell_impl.GRUCell(
         3, kernel_initializer=init_ops.constant_initializer(0.5),
         bias_initializer=init_ops.constant_initializer(0.5))
@@ -116,7 +119,8 @@
     base_cell = layers.SimpleRNNCell(1, name="basic_rnn_cell")
     rnn_cell = wrapper(base_cell)
     rnn_layer = layers.RNN(rnn_cell)
-    inputs = ops.convert_to_tensor_v2([[[1]]], dtype=dtypes.float32)
+    inputs = ops.convert_to_tensor_v2_with_dispatch([[[1]]],
+                                                    dtype=dtypes.float32)
     rnn_layer(inputs)
 
     wrapper_name = generic_utils.to_snake_case(wrapper.__name__)
@@ -140,8 +144,8 @@
       base_cell = rnn_cell_impl.MultiRNNCell(
           [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
     rnn_cell = wrapper(base_cell)
-    inputs = ops.convert_to_tensor_v2([[1]], dtype=dtypes.float32)
-    state = ops.convert_to_tensor_v2([[1]], dtype=dtypes.float32)
+    inputs = ops.convert_to_tensor_v2_with_dispatch([[1]], dtype=dtypes.float32)
+    state = ops.convert_to_tensor_v2_with_dispatch([[1]], dtype=dtypes.float32)
     _ = rnn_cell(inputs, [state, state])
     weights = base_cell._cells[0].weights
     self.assertLen(weights, expected_len=2)
diff --git a/tensorflow/python/keras/layers/subclassed_layers_test.py b/tensorflow/python/keras/layers/subclassed_layers_test.py
index 6adeb09..572ce85 100644
--- a/tensorflow/python/keras/layers/subclassed_layers_test.py
+++ b/tensorflow/python/keras/layers/subclassed_layers_test.py
@@ -37,7 +37,7 @@
     class BuildConstantLayer(keras.layers.Layer):
 
       def build(self, input_shape):
-        self.b = ops.convert_to_tensor_v2(2.0)
+        self.b = ops.convert_to_tensor_v2_with_dispatch(2.0)
 
       def call(self, inputs):
         return self.b * inputs
@@ -46,7 +46,7 @@
     model = testing_utils.get_model_from_layers(
         [layer, keras.layers.Dense(1)], input_shape=(1,))
 
-    x = ops.convert_to_tensor_v2([[3.0]])
+    x = ops.convert_to_tensor_v2_with_dispatch([[3.0]])
     self.assertEqual(
         tf_utils.is_symbolic_tensor(model(x)), not context.executing_eagerly())
     self.assertEqual(
@@ -58,10 +58,10 @@
     class BuildDerivedConstantLayer(keras.layers.Layer):
 
       def build(self, input_shape):
-        a = ops.convert_to_tensor_v2(1.0)
+        a = ops.convert_to_tensor_v2_with_dispatch(1.0)
         b = 2.0 * a
         self.variable = variables.Variable(b)
-        self.constant = ops.convert_to_tensor_v2(self.variable)
+        self.constant = ops.convert_to_tensor_v2_with_dispatch(self.variable)
 
       def call(self, inputs):
         return self.variable * self.constant * inputs
@@ -70,7 +70,7 @@
     model = testing_utils.get_model_from_layers(
         [layer, keras.layers.Dense(1)], input_shape=(1,))
 
-    x = ops.convert_to_tensor_v2([[3.0]])
+    x = ops.convert_to_tensor_v2_with_dispatch([[3.0]])
     self.assertEqual(
         tf_utils.is_symbolic_tensor(model(x)), not context.executing_eagerly())
     self.assertEqual(
diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
index e128323..0293233 100644
--- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
+++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
@@ -637,7 +637,7 @@
     self.assertAllEqual(model(ones), 3.0 * ones)
 
   def test_numerical_correctness_simple(self):
-    x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
+    x = ops.convert_to_tensor_v2_with_dispatch([[-1., 0., -2., 1.]])
     inputs = keras.Input(shape=(4,))
     outputs = gen_nn_ops.relu(inputs)
     model = keras.Model(inputs, outputs)
@@ -645,7 +645,7 @@
     self.assertAllClose(y, [[0., 0., 0., 1.]])
 
   def test_numerical_correctness_with_attrs(self):
-    x = ops.convert_to_tensor_v2([[1.5, 1.5], [2.5, 3.5]])
+    x = ops.convert_to_tensor_v2_with_dispatch([[1.5, 1.5], [2.5, 3.5]])
     inputs = keras.Input(shape=(2,))
     outputs = math_ops.reduce_mean(inputs, axis=1)
     model = keras.Model(inputs, outputs)
@@ -653,7 +653,7 @@
     self.assertAllClose(y, [1.5, 3.])
 
   def test_numerical_correctness_serialization(self):
-    x = ops.convert_to_tensor_v2([[-1., 0., -2., 1.]])
+    x = ops.convert_to_tensor_v2_with_dispatch([[-1., 0., -2., 1.]])
     inputs = keras.Input(shape=(4,))
     outputs = gen_nn_ops.relu(inputs)
     model1 = keras.Model(inputs, outputs)
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index 671fe65..4f0eee8 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -44,7 +44,6 @@
 from tensorflow.python.platform import test
 from tensorflow.python.training.tracking import util as trackable_util
 from tensorflow.python.util import nest
-from tensorflow.python.util import object_identity
 
 
 class _RNNCellWithConstants(keras.layers.Layer):
@@ -130,10 +129,11 @@
 
     # check whether the model variables are present in the
     # trackable list of objects
-    checkpointed_objects = object_identity.ObjectIdentitySet(
-        trackable_util.list_objects(model))
+    checkpointed_object_ids = {
+        id(o) for o in trackable_util.list_objects(model)
+    }
     for v in model.variables:
-      self.assertIn(v, checkpointed_objects)
+      self.assertIn(id(v), checkpointed_object_ids)
 
   def test_timedistributed_static_batch_size(self):
     model = keras.models.Sequential()
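The `ObjectIdentitySet` replacement above works because `id()` gives identity-based membership without ever invoking `__eq__`, which on TF2 variables is elementwise and therefore unsuitable for set lookups. A minimal sketch with a hypothetical stand-in class:

    class ElementwiseEq(object):
      """Stand-in for an object whose __eq__ is elementwise, like tf.Variable."""

      def __eq__(self, other):
        raise TypeError("elementwise comparison has no single truth value")

      __hash__ = object.__hash__

    v = ElementwiseEq()
    tracked_ids = {id(o) for o in [v, ElementwiseEq()]}
    assert id(v) in tracked_ids  # compares plain ints; __eq__ never runs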
@@ -492,10 +492,11 @@
 
       # check whether the model variables are present in the
       # trackable list of objects
-      checkpointed_objects = object_identity.ObjectIdentitySet(
-          trackable_util.list_objects(model))
+      checkpointed_object_ids = {
+          id(o) for o in trackable_util.list_objects(model)
+      }
       for v in model.variables:
-        self.assertIn(v, checkpointed_objects)
+        self.assertIn(id(v), checkpointed_object_ids)
 
       # test compute output shape
       ref_shape = model.layers[-1].output.shape
@@ -1030,10 +1031,11 @@
 
     # check whether the model variables are present in the
     # trackable list of objects
-    checkpointed_objects = object_identity.ObjectIdentitySet(
-        trackable_util.list_objects(model))
+    checkpointed_object_ids = {
+        id(o) for o in trackable_util.list_objects(model)
+    }
     for v in model.variables:
-      self.assertIn(v, checkpointed_objects)
+      self.assertIn(id(v), checkpointed_object_ids)
 
     # test compute output shape
     ref_shape = model.layers[-1].output.shape
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index bda3289..6b74121 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -1189,7 +1189,7 @@
   Returns:
     Mean squared error values. shape = `[batch_size, d0, .. dN-1]`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   return K.mean(math_ops.squared_difference(y_pred, y_true), axis=-1)
 
@@ -1222,7 +1222,7 @@
   Returns:
     Mean absolute error values. shape = `[batch_size, d0, .. dN-1]`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   return K.mean(math_ops.abs(y_pred - y_true), axis=-1)
 
@@ -1257,7 +1257,7 @@
   Returns:
     Mean absolute percentage error values. shape = `[batch_size, d0, .. dN-1]`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   diff = math_ops.abs(
       (y_true - y_pred) / K.maximum(math_ops.abs(y_true), K.epsilon()))
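Written out, the quantity assembled above is the mean absolute percentage error; per the surrounding Keras definition the ratio is averaged over the last axis and scaled by 100, with epsilon guarding against division by zero:

    \[
      \mathrm{MAPE}(y_{\mathrm{true}}, y_{\mathrm{pred}}) =
      100 \cdot \operatorname{mean}\!\left(
        \frac{\lvert y_{\mathrm{true}} - y_{\mathrm{pred}}\rvert}
             {\max(\lvert y_{\mathrm{true}}\rvert,\ \epsilon)}
      \right)
    \]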
@@ -1296,7 +1296,7 @@
   Returns:
     Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   first_log = math_ops.log(K.maximum(y_pred, K.epsilon()) + 1.)
   second_log = math_ops.log(K.maximum(y_true, K.epsilon()) + 1.)
@@ -1344,7 +1344,7 @@
   Returns:
      Squared hinge loss values. shape = `[batch_size, d0, .. dN-1]`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   y_true = _maybe_convert_labels(y_true)
   return K.mean(
@@ -1377,7 +1377,7 @@
   Returns:
     Hinge loss values. shape = `[batch_size, d0, .. dN-1]`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   y_true = _maybe_convert_labels(y_true)
   return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1)
@@ -1409,7 +1409,7 @@
   Returns:
     Categorical hinge loss values.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   pos = math_ops.reduce_sum(y_true * y_pred, axis=-1)
   neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1)
@@ -1444,7 +1444,7 @@
   delta = math_ops.cast(delta, dtype=K.floatx())
   error = math_ops.subtract(y_pred, y_true)
   abs_error = math_ops.abs(error)
-  half = ops.convert_to_tensor_v2(0.5, dtype=abs_error.dtype)
+  half = ops.convert_to_tensor_v2_with_dispatch(0.5, dtype=abs_error.dtype)
   return K.mean(
       array_ops.where_v2(
           abs_error <= delta, half * math_ops.pow(error, 2),
@@ -1481,7 +1481,7 @@
   Returns:
     Logcosh error values. shape = `[batch_size, d0, .. dN-1]`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
 
   def _logcosh(x):
@@ -1518,9 +1518,10 @@
   Returns:
     Categorical crossentropy loss value.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
-  label_smoothing = ops.convert_to_tensor_v2(label_smoothing, dtype=K.floatx())
+  label_smoothing = ops.convert_to_tensor_v2_with_dispatch(
+      label_smoothing, dtype=K.floatx())
 
   def _smooth_labels():
     num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
@@ -1557,7 +1558,7 @@
   Returns:
     Sparse categorical crossentropy loss value.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   return K.sparse_categorical_crossentropy(
       y_true, y_pred, from_logits=from_logits, axis=axis)
@@ -1588,9 +1589,10 @@
   Returns:
     Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
-  label_smoothing = ops.convert_to_tensor_v2(label_smoothing, dtype=K.floatx())
+  label_smoothing = ops.convert_to_tensor_v2_with_dispatch(
+      label_smoothing, dtype=K.floatx())
 
   def _smooth_labels():
     return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
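Both crossentropy hunks above apply the same smoothing rule; the binary branch is the K = 2 special case of interpolating hard targets toward a uniform distribution. A small sketch (hypothetical helper name):

    def smooth_labels(y_true, label_smoothing, num_classes):
      # A hard 1 becomes 1 - ls + ls / K and a hard 0 becomes ls / K.
      # With K == 2 this reduces to y_true * (1 - ls) + 0.5 * ls, matching
      # the binary branch above.
      return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes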
@@ -1638,7 +1640,7 @@
   Raises:
     TypeError: If `y_true` cannot be cast to the `y_pred.dtype`.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   y_true = K.clip(y_true, K.epsilon(), 1)
   y_pred = K.clip(y_pred, K.epsilon(), 1)
@@ -1674,7 +1676,7 @@
   Raises:
     InvalidArgumentError: If `y_true` and `y_pred` have incompatible shapes.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   y_true = math_ops.cast(y_true, y_pred.dtype)
   return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1)
 
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 34213c8..4de49e6 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -95,16 +95,19 @@
     p = backend.placeholder()
     o = losses.categorical_crossentropy(t, p)
 
-    t_val = ops.convert_to_tensor_v2([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
-    p_val = ops.convert_to_tensor_v2([[.9, .05, .05], [.05, .89, .06],
-                                      [.05, .01, .94]])
+    t_val = ops.convert_to_tensor_v2_with_dispatch([[1., 0., 0.], [0., 1., 0.],
+                                                    [0., 0., 1.]])
+    p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05],
+                                                    [.05, .89, .06],
+                                                    [.05, .01, .94]])
     f = backend.function([t, p], o)
 
     result = f([t_val, p_val])
     self.assertArrayNear(result, [.105, .116, .062], 1e-3)
 
     # from logits
-    p_val = ops.convert_to_tensor_v2([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+    p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.],
+                                                    [2., 3., 5.]])
     o = losses.categorical_crossentropy(t, p, from_logits=True)
     f = backend.function([t, p], o)
 
@@ -133,16 +136,18 @@
     p = backend.placeholder()
     o = losses.sparse_categorical_crossentropy(t, p)
 
-    t_val = ops.convert_to_tensor_v2([0, 1, 2])
-    p_val = ops.convert_to_tensor_v2([[.9, .05, .05], [.05, .89, .06],
-                                      [.05, .01, .94]])
+    t_val = ops.convert_to_tensor_v2_with_dispatch([0, 1, 2])
+    p_val = ops.convert_to_tensor_v2_with_dispatch([[.9, .05, .05],
+                                                    [.05, .89, .06],
+                                                    [.05, .01, .94]])
     f = backend.function([t, p], o)
 
     result = f([t_val, p_val])
     self.assertArrayNear(result, [.105, .116, .062], 1e-3)
 
     # from logits
-    p_val = ops.convert_to_tensor_v2([[8., 1., 1.], [0., 9., 1.], [2., 3., 5.]])
+    p_val = ops.convert_to_tensor_v2_with_dispatch([[8., 1., 1.], [0., 9., 1.],
+                                                    [2., 3., 5.]])
     o = losses.sparse_categorical_crossentropy(t, p, from_logits=True)
     f = backend.function([t, p], o)
 
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index b3f391c..eea1881 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -963,7 +963,7 @@
       result = self.accumulator[0]
     else:
       result = self.accumulator
-    return ops.convert_to_tensor_v2(result)
+    return ops.convert_to_tensor_v2_with_dispatch(result)
 
   def reset_states(self):
     num_thresholds = len(to_list(self.thresholds))
@@ -3239,7 +3239,7 @@
   Returns:
     Binary accuracy values. shape = `[batch_size, d0, .. dN-1]`
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
   threshold = math_ops.cast(threshold, y_pred.dtype)
   y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype)
   return K.mean(math_ops.equal(y_true, y_pred), axis=-1)
@@ -3297,8 +3297,8 @@
   Returns:
     Sparse categorical accuracy values.
   """
-  y_pred = ops.convert_to_tensor_v2(y_pred)
-  y_true = ops.convert_to_tensor_v2(y_true)
+  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
+  y_true = ops.convert_to_tensor_v2_with_dispatch(y_true)
   y_pred_rank = y_pred.shape.ndims
   y_true_rank = y_true.shape.ndims
   # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
@@ -3364,8 +3364,8 @@
   Returns:
     Sparse top K categorical accuracy value.
   """
-  y_pred_rank = ops.convert_to_tensor_v2(y_pred).shape.ndims
-  y_true_rank = ops.convert_to_tensor_v2(y_true).shape.ndims
+  y_pred_rank = ops.convert_to_tensor_v2_with_dispatch(y_pred).shape.ndims
+  y_true_rank = ops.convert_to_tensor_v2_with_dispatch(y_true).shape.ndims
   # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)
   if (y_true_rank is not None) and (y_pred_rank is not None):
     if y_pred_rank > 2:
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 7b339fc..554609d 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -71,7 +71,7 @@
       self.assertEqual(self.evaluate(m.total), 100)
 
       # check update_state() and result() + state accumulation + tensor input
-      update_op = m.update_state(ops.convert_to_tensor_v2([1, 5]))
+      update_op = m.update_state(ops.convert_to_tensor_v2_with_dispatch([1, 5]))
       self.evaluate(update_op)
       self.assertAlmostEqual(self.evaluate(m.result()), 106)
       self.assertEqual(self.evaluate(m.total), 106)  # 100 + 1 + 5
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
index 2077006..d0d6442 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
@@ -126,7 +126,7 @@
       raise ValueError(
           'Incompatible type conversion requested to type {!r} for variable '
           'of type {!r}'.format(dtype.name, self.dtype.name))
-    val = ops.convert_to_tensor_v2(
+    val = ops.convert_to_tensor_v2_with_dispatch(
         self._variable, dtype=self._variable.dtype, name=name)
     return math_ops.cast(val, self.dtype)
 
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
index 9a9d174..75fe3d9 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py
@@ -124,10 +124,10 @@
   def testGetScaledLoss(self):
     opt = gradient_descent.SGD(2.0)
     opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.)
-    loss = ops.convert_to_tensor_v2(5.)
+    loss = ops.convert_to_tensor_v2_with_dispatch(5.)
     self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
     self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)()))
-    loss = ops.convert_to_tensor_v2(5., dtype='float16')
+    loss = ops.convert_to_tensor_v2_with_dispatch(5., dtype='float16')
     self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
     self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)()))
 
@@ -135,8 +135,8 @@
     opt = gradient_descent.SGD(2.0)
     opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2)
     scaled_grads = [
-        ops.convert_to_tensor_v2(3.), None,
-        ops.convert_to_tensor_v2(-4., dtype='float16')
+        ops.convert_to_tensor_v2_with_dispatch(3.), None,
+        ops.convert_to_tensor_v2_with_dispatch(-4., dtype='float16')
     ]
     grads = opt.get_unscaled_gradients(scaled_grads)
     grads = [self.evaluate(g) if g is not None else g for g in grads]
@@ -146,9 +146,10 @@
     opt = gradient_descent.SGD(2.0)
     opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2)
     sparse_scaled_grad = ops.IndexedSlices(
-        ops.convert_to_tensor_v2([[4., 2.], [8., 5.]]),
-        ops.convert_to_tensor_v2([1, 3], dtype='int32'),
-        dense_shape=ops.convert_to_tensor_v2([5, 2], dtype='int32'))
+        ops.convert_to_tensor_v2_with_dispatch([[4., 2.], [8., 5.]]),
+        ops.convert_to_tensor_v2_with_dispatch([1, 3], dtype='int32'),
+        dense_shape=ops.convert_to_tensor_v2_with_dispatch([5, 2],
+                                                           dtype='int32'))
     sparse_grad = opt.get_unscaled_gradients([sparse_scaled_grad])[0]
     self.assertIsInstance(sparse_grad, ops.IndexedSlices)
     self.assertAllEqual([[2., 1.], [4., 2.5]],
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py
index 592057f..c8acd86 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py
@@ -543,7 +543,12 @@
   passed to the layer constructor. If no global policy is set, layers will
   instead default to a Policy constructed from `tf.keras.backend.floatx()`.
 
-  See `keras.mixed_precision.experimental.Policy` for more information.
+  Only floating point policies can be set as the global policy, such as
+  `'float32'` and `'mixed_float16'`. Non-floating point policies such as
+  `'int32'` and `'complex64'` cannot be set as the global policy because most
+  layers do not support such policies.
+
+  See `tf.keras.mixed_precision.experimental.Policy` for more information.
 
   Args:
     policy: A Policy, or a string that will be converted to a Policy.
@@ -559,6 +564,12 @@
   is_mixed_policy = policy is not None and policy.should_cast_variables
   if is_mixed_policy:
     _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
+  if (policy is not None and policy.compute_dtype is not None and
+      not dtypes.as_dtype(policy.compute_dtype).is_floating):
+    raise ValueError('set_policy can only be used to set the global policy to '
+                     'floating-point policies, such as "float32" and '
+                     '"mixed_float16", but got policy: %s'
+                     % (policy.name,))
   _global_policy = policy
   mixed_precision_global_state.using_mixed_precision_policy = is_mixed_policy
 
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
index 94880a9..060f80f 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
@@ -156,6 +156,21 @@
       mp_policy.set_policy(None)
 
   @testing_utils.enable_v2_dtype_behavior
+  def test_global_policy_dtype_error(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        'set_policy can only be used to set the global policy to '
+        'floating-point policies, such as "float32" and "mixed_float16", but '
+        'got policy: int32'):
+      mp_policy.set_policy('int32')
+    with self.assertRaisesRegex(
+        ValueError,
+        'set_policy can only be used to set the global policy to '
+        'floating-point policies, such as "float32" and "mixed_float16", but '
+        'got policy: complex64'):
+      mp_policy.set_policy(mp_policy.Policy('complex64'))
+
+  @testing_utils.enable_v2_dtype_behavior
   def test_loss_scale_warning(self):
     with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
       mp_policy.Policy('float32', loss_scale=2.)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/test_util.py b/tensorflow/python/keras/mixed_precision/experimental/test_util.py
index 937b378..c0d9cbf 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/test_util.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/test_util.py
@@ -55,7 +55,7 @@
       if expected_dtype:
         assert dx.dtype == expected_dtype, (
             'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype))
-      expected_tensor = ops.convert_to_tensor_v2(
+      expected_tensor = ops.convert_to_tensor_v2_with_dispatch(
           expected_gradient, dtype=dx.dtype, name='expected_gradient')
       # Control dependency is to ensure input is available. It's possible the
       # dataset will throw a StopIteration to indicate there is no more data, in
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py
index 8c895ae..404b3f8 100644
--- a/tensorflow/python/keras/optimizer_v2/adadelta.py
+++ b/tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -101,7 +101,8 @@
     super(Adadelta, self)._prepare_local(var_device, var_dtype, apply_state)
     apply_state[(var_device, var_dtype)].update(
         dict(
-            epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
+            epsilon=ops.convert_to_tensor_v2_with_dispatch(
+                self.epsilon, var_dtype),
             rho=array_ops.identity(self._get_hyper('rho', var_dtype))))
 
   def set_weights(self, weights):
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
index ba76b83..4d3294a 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -87,7 +87,8 @@
     super(Adagrad, self)._prepare_local(var_device, var_dtype, apply_state)
     apply_state[(var_device, var_dtype)].update(
         dict(
-            epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
+            epsilon=ops.convert_to_tensor_v2_with_dispatch(
+                self.epsilon, var_dtype),
             neg_lr_t=-apply_state[(var_device, var_dtype)]['lr_t'],
             zero=array_ops.zeros((), dtype=dtypes.int64)))
 
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
index 1fccd11..e4896fd 100644
--- a/tensorflow/python/keras/optimizer_v2/adam.py
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -144,7 +144,8 @@
     apply_state[(var_device, var_dtype)].update(
         dict(
             lr=lr,
-            epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
+            epsilon=ops.convert_to_tensor_v2_with_dispatch(
+                self.epsilon, var_dtype),
             beta_1_t=beta_1_t,
             beta_1_power=beta_1_power,
             one_minus_beta_1_t=1 - beta_1_t,
@@ -396,7 +397,8 @@
     apply_state[(var_device, var_dtype)].update(
         dict(
             lr=lr,
-            epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
+            epsilon=ops.convert_to_tensor_v2_with_dispatch(
+                self.epsilon, var_dtype),
             beta_1_t=beta_1_t,
             beta_1_power=beta_1_power,
             one_minus_beta_1_t=1 - beta_1_t,
diff --git a/tensorflow/python/keras/optimizer_v2/adamax.py b/tensorflow/python/keras/optimizer_v2/adamax.py
index 3f4312c..26cc59b 100644
--- a/tensorflow/python/keras/optimizer_v2/adamax.py
+++ b/tensorflow/python/keras/optimizer_v2/adamax.py
@@ -122,7 +122,8 @@
     apply_state[(var_device, var_dtype)].update(
         dict(
             neg_scaled_lr=-lr_t / (1 - beta_1_power),
-            epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
+            epsilon=ops.convert_to_tensor_v2_with_dispatch(
+                self.epsilon, var_dtype),
             beta_1_t=beta_1_t,
             beta_1_power=beta_1_power,
             one_minus_beta_1_t=1 - beta_1_t,
diff --git a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
index 4dcff3d..30b4f21 100644
--- a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
+++ b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
@@ -143,7 +143,7 @@
 
   def __call__(self, step):
     with ops.name_scope_v2(self.name or "ExponentialDecay") as name:
-      initial_learning_rate = ops.convert_to_tensor_v2(
+      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
           self.initial_learning_rate, name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       decay_steps = math_ops.cast(self.decay_steps, dtype)
@@ -237,11 +237,11 @@
 
   def __call__(self, step):
     with ops.name_scope_v2(self.name or "PiecewiseConstant"):
-      boundaries = nest.map_structure(ops.convert_to_tensor_v2,
+      boundaries = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
                                       nest.flatten(self.boundaries))
-      values = nest.map_structure(ops.convert_to_tensor_v2,
+      values = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
                                   nest.flatten(self.values))
-      x_recomp = ops.convert_to_tensor_v2(step)
+      x_recomp = ops.convert_to_tensor_v2_with_dispatch(step)
       for i, b in enumerate(boundaries):
         if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
           # We cast the boundaries to have the same type as the step
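The boundary and value tensors converted above back `tf.keras.optimizers.schedules.PiecewiseConstantDecay`; a short usage sketch of the piecewise lookup:

    import tensorflow as tf

    schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[1000, 2000],   # steps at which the rate changes
        values=[1.0, 0.5, 0.1])    # len(values) == len(boundaries) + 1
    print(schedule(500).numpy())   # 1.0: step <= 1000
    print(schedule(1500).numpy())  # 0.5: 1000 < step <= 2000
    print(schedule(3000).numpy())  # 0.1: step > 2000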
@@ -374,7 +374,7 @@
 
   def __call__(self, step):
     with ops.name_scope_v2(self.name or "PolynomialDecay") as name:
-      initial_learning_rate = ops.convert_to_tensor_v2(
+      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
           self.initial_learning_rate, name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       end_learning_rate = math_ops.cast(self.end_learning_rate, dtype)
@@ -494,7 +494,7 @@
 
   def __call__(self, step):
     with ops.name_scope_v2(self.name or "InverseTimeDecay") as name:
-      initial_learning_rate = ops.convert_to_tensor_v2(
+      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
           self.initial_learning_rate, name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       decay_steps = math_ops.cast(self.decay_steps, dtype)
@@ -588,7 +588,7 @@
 
   def __call__(self, step):
     with ops.name_scope_v2(self.name or "CosineDecay"):
-      initial_learning_rate = ops.convert_to_tensor_v2(
+      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
           self.initial_learning_rate, name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       decay_steps = math_ops.cast(self.decay_steps, dtype)
@@ -687,7 +687,7 @@
 
   def __call__(self, step):
     with ops.name_scope_v2(self.name or "SGDRDecay") as name:
-      initial_learning_rate = ops.convert_to_tensor_v2(
+      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
           self.initial_learning_rate, name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       first_decay_steps = math_ops.cast(self.first_decay_steps, dtype)
@@ -824,7 +824,7 @@
 
   def __call__(self, step):
     with ops.name_scope_v2(self.name or "LinearCosineDecay") as name:
-      initial_learning_rate = ops.convert_to_tensor_v2(
+      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
           self.initial_learning_rate, name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       decay_steps = math_ops.cast(self.decay_steps, dtype)
@@ -950,7 +950,7 @@
 
   def __call__(self, step):
     with ops.name_scope_v2(self.name or "NoisyLinearCosineDecay") as name:
-      initial_learning_rate = ops.convert_to_tensor_v2(
+      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
           self.initial_learning_rate, name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       decay_steps = math_ops.cast(self.decay_steps, dtype)
diff --git a/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py b/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py
index ad28056..ab8e4f5 100644
--- a/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py
+++ b/tensorflow/python/keras/optimizer_v2/legacy_learning_rate_decay.py
@@ -148,10 +148,11 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  boundaries = nest.map_structure(ops.convert_to_tensor_v2,
+  boundaries = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
                                   nest.flatten(boundaries))
-  values = nest.map_structure(ops.convert_to_tensor_v2, nest.flatten(values))
-  x_recomp = ops.convert_to_tensor(x)
+  values = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
+                              nest.flatten(values))
+  x_recomp = ops.convert_to_tensor_v2_with_dispatch(x)
   # Avoid explicit conversion to x's dtype. This could result in faulty
   # comparisons, for example if floats are converted to integers.
   for i, b in enumerate(boundaries):
diff --git a/tensorflow/python/keras/optimizer_v2/nadam.py b/tensorflow/python/keras/optimizer_v2/nadam.py
index 090eaba..550db0f 100644
--- a/tensorflow/python/keras/optimizer_v2/nadam.py
+++ b/tensorflow/python/keras/optimizer_v2/nadam.py
@@ -122,7 +122,7 @@
     apply_state[(var_device, var_dtype)] = dict(
         lr_t=lr_t,
         neg_lr_t=-lr_t,
-        epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
+        epsilon=ops.convert_to_tensor_v2_with_dispatch(self.epsilon, var_dtype),
         beta_1_t=beta_1_t,
         beta_2_t=beta_2_t,
         m_t=m_t,
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index e6b4458..1227e66 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -304,26 +304,29 @@
     ```
 
     Args:
-      name: A non-empty string.  The name to use for accumulators created
-        for the optimizer.
+      name: String. The name to use for momentum accumulator weights created
+        by the optimizer.
       gradient_aggregator: The function to use to aggregate gradients across
         devices (when using `tf.distribute.Strategy`). If `None`, defaults to
         summing the gradients across devices. The function should accept and
         return a list of `(gradient, variable)` tuples.
-      gradient_transformers: (Optional). List of functions to use to transform
-        gradients before applying updates to `Variable`s. The functions are
+      gradient_transformers: Optional. List of functions to use to transform
+        gradients before applying updates to Variables. The functions are
         applied after `gradient_aggregator`. The functions should accept and
         return a list of `(gradient, variable)` tuples.
-      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
-        `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
-        gradients by value, `decay` is included for backward compatibility to
-        allow time inverse decay of learning rate. `lr` is included for backward
-        compatibility, recommended to use `learning_rate` instead.
+      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
+        `clipnorm` and `global_clipnorm`, plus `lr` and `decay`, which are
+        accepted only for backward compatibility (prefer `learning_rate`
+        over `lr`).
+        If `clipvalue` (float) is set, the gradient of each weight
+        is clipped to be no higher than this value.
+        If `clipnorm` (float) is set, the gradient of each weight
+        is individually clipped so that its norm is no higher than this value.
+        If `global_clipnorm` (float) is set, the gradient of all weights is
+        clipped so that their global norm is no higher than this value.
 
     Raises:
-      ValueError: If name is malformed.
+      ValueError: If any argument is invalid.
     """
-    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay"}
+    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"}
     for k in kwargs:
       if k not in allowed_kwargs:
         raise TypeError("Unexpected keyword argument "
@@ -370,6 +373,11 @@
       gradient_transformers = []
     self.gradient_transformers = gradient_transformers
     self.clipnorm = kwargs.pop("clipnorm", None)
+    self.global_clipnorm = kwargs.pop("global_clipnorm", None)
+    if self.clipnorm is not None and self.global_clipnorm is not None:
+      raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, "
+                       "passed `clipnorm` {}, `global_clipnorm` {}".format(
+                           self.clipnorm, self.global_clipnorm))
     self.clipvalue = kwargs.pop("clipvalue", None)
 
   @property
@@ -377,6 +385,11 @@
     """`float` or `None`. If set, clips gradients to a maximum norm."""
     return self._clipnorm
 
+  @property
+  def global_clipnorm(self):
+    """`float` or `None`. If set, clips gradients to a maximum norm."""
+    return self._global_clipnorm
+
   @clipnorm.setter
   def clipnorm(self, val):
     if val is not None and self.gradient_transformers:
@@ -387,6 +400,16 @@
     self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
         self._clipnorm)
 
+  @global_clipnorm.setter
+  def global_clipnorm(self, val):
+    if val is not None and self.gradient_transformers:
+      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
+                       "is set. Instead, use the `gradient_transformers` to "
+                       "specify clipping and other transformations.")
+    self._global_clipnorm = val
+    self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn(
+        self._global_clipnorm)
+
   @property
   def clipvalue(self):
     """`float` or `None`. If set, clips gradients to a maximum value."""
@@ -425,6 +448,8 @@
       grads_and_vars = self._clipvalue_fn(grads_and_vars)
     if self._clipnorm is not None:
       grads_and_vars = self._clipnorm_fn(grads_and_vars)
+    if self._global_clipnorm is not None:
+      grads_and_vars = self._global_clipnorm_fn(grads_and_vars)
 
     for fn in self.gradient_transformers:
       grads_and_vars = fn(grads_and_vars)
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
index e994a6e..65539fa 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -237,7 +237,7 @@
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def testComputeGradientsWithTensors(self):
     with testing_utils.use_gpu():
-      x = ops.convert_to_tensor_v2(1.0)
+      x = ops.convert_to_tensor_v2_with_dispatch(1.0)
 
       def f():
         return x * x
@@ -357,6 +357,22 @@
       self.assertAllClose([0.], self.evaluate(var))
 
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testGradGlobalClipNorm(self):
+    with testing_utils.use_gpu():
+      # l2 norm is 5.0
+      var1 = variables.Variable([1.0])
+      var2 = variables.Variable([2.0])
+      loss = lambda: 3 * var1 + 4 * var2
+      opt = gradient_descent.SGD(learning_rate=1.0, global_clipnorm=2.0)
+      opt_op = opt.minimize(loss, [var1, var2])
+      self.evaluate(variables.global_variables_initializer())
+      self.evaluate(opt_op)
+      # grad1 = 3.0 * 2.0 / 5.0 = 1.2
+      self.assertAllClose([-.2], self.evaluate(var1))
+      # grad2 = 4.0 * 2.0 / 5.0 = 1.6
+      self.assertAllClose([.4], self.evaluate(var2))
+
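To spell out the comments in `testGradGlobalClipNorm` above: the global norm of the gradients (3.0, 4.0) is 5.0, so clipping to 2.0 rescales every gradient by 2/5. In plain Python:

    import math

    grads = [3.0, 4.0]
    global_norm = math.sqrt(sum(g * g for g in grads))  # sqrt(9 + 16) = 5.0
    scale = 2.0 / global_norm                           # clip_norm / global_norm = 0.4
    clipped = [g * scale for g in grads]                # [1.2, 1.6]
    # var1: 1.0 - 1.2 = -0.2 and var2: 2.0 - 1.6 = 0.4, matching the assertions.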
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def testInvalidClipNorm(self):
     with self.assertRaisesRegex(ValueError, '>= 0'):
       gradient_descent.SGD(learning_rate=1.0, clipnorm=-1.0)
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py
index 1fa2577..407dbf3 100644
--- a/tensorflow/python/keras/optimizer_v2/rmsprop.py
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -167,7 +167,8 @@
     apply_state[(var_device, var_dtype)].update(
         dict(
             neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"],
-            epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
+            epsilon=ops.convert_to_tensor_v2_with_dispatch(
+                self.epsilon, var_dtype),
             rho=rho,
             momentum=array_ops.identity(self._get_hyper("momentum", var_dtype)),
             one_minus_rho=1. - rho))
diff --git a/tensorflow/python/keras/optimizer_v2/utils.py b/tensorflow/python/keras/optimizer_v2/utils.py
index 909a25d..4495879 100644
--- a/tensorflow/python/keras/optimizer_v2/utils.py
+++ b/tensorflow/python/keras/optimizer_v2/utils.py
@@ -104,6 +104,26 @@
   return gradient_clipnorm_fn
 
 
+def make_global_gradient_clipnorm_fn(clipnorm):
+  """Creates a gradient transformation function for clipping by norm."""
+  if clipnorm is None:
+    return lambda grads_and_vars: grads_and_vars
+
+  def gradient_clipnorm_fn(grads_and_vars):
+
+    if isinstance(distribute_ctx.get_strategy(),
+                  central_storage_strategy.CentralStorageStrategy):
+      raise ValueError(
+          "`global_clipnorm` is not supported with `CenteralStorageStrategy`")
+
+    grads, variables = zip(*grads_and_vars)
+    clipped_grads, _ = clip_ops.clip_by_global_norm(grads, clipnorm)
+    clipped_grads_and_vars = list(zip(clipped_grads, variables))
+    return clipped_grads_and_vars
+
+  return gradient_clipnorm_fn
+
+
 def make_gradient_clipvalue_fn(clipvalue):
   """Creates a gradient transformation function for clipping by value."""
   if clipvalue is None:
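The helper above defers the actual scaling to `clip_ops.clip_by_global_norm`, which returns the clipped list together with the pre-clipping global norm; a small sketch against the public alias:

    import tensorflow as tf

    grads = [tf.constant([6.0]), tf.constant([8.0])]
    clipped, norm = tf.clip_by_global_norm(grads, clip_norm=5.0)
    print(norm.numpy())                  # 10.0 == sqrt(36 + 64)
    print([g.numpy() for g in clipped])  # [3.], [4.]: each gradient scaled by 5/10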
diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py
index f943967..40ec3d1 100644
--- a/tensorflow/python/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/preprocessing/image.py
@@ -111,7 +111,7 @@
   if len(size) != 2:
     raise ValueError('Expected `size` to be a tuple of 2 integers, '
                      'but got: %s' % (size,))
-  img = ops.convert_to_tensor(x)
+  img = ops.convert_to_tensor_v2_with_dispatch(x)
   if img.shape.rank is not None:
     if img.shape.rank != 3:
       raise ValueError(
diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD
index 62000be..2e1b6cd 100644
--- a/tensorflow/python/keras/saving/BUILD
+++ b/tensorflow/python/keras/saving/BUILD
@@ -177,7 +177,7 @@
     size = "medium",
     srcs = ["saved_model/revive_test.py"],
     python_version = "PY3",
-    shard_count = 4,
+    shard_count = 8,
     tags = [
         "no_windows",  # b/158005583
     ],
diff --git a/tensorflow/python/keras/saving/hdf5_format_test.py b/tensorflow/python/keras/saving/hdf5_format_test.py
index 92296b5..1817bfc 100644
--- a/tensorflow/python/keras/saving/hdf5_format_test.py
+++ b/tensorflow/python/keras/saving/hdf5_format_test.py
@@ -1276,7 +1276,7 @@
       prefix = 'ackpt'
       self.evaluate(v.assign(42.))
       m.save_weights(prefix)
-      self.assertTrue(file_io.file_exists('ackpt.index'))
+      self.assertTrue(file_io.file_exists_v2('ackpt.index'))
       self.evaluate(v.assign(1.))
       m.load_weights(prefix)
       self.assertEqual(42., self.evaluate(v))
@@ -1284,7 +1284,7 @@
       prefix = 'subdir/ackpt'
       self.evaluate(v.assign(43.))
       m.save_weights(prefix)
-      self.assertTrue(file_io.file_exists('subdir/ackpt.index'))
+      self.assertTrue(file_io.file_exists_v2('subdir/ackpt.index'))
       self.evaluate(v.assign(2.))
       m.load_weights(prefix)
       self.assertEqual(43., self.evaluate(v))
@@ -1292,7 +1292,7 @@
       prefix = 'ackpt/'
       self.evaluate(v.assign(44.))
       m.save_weights(prefix)
-      self.assertTrue(file_io.file_exists('ackpt/.index'))
+      self.assertTrue(file_io.file_exists_v2('ackpt/.index'))
       self.evaluate(v.assign(3.))
       m.load_weights(prefix)
       self.assertEqual(44., self.evaluate(v))
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index c016060..556675d 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -43,7 +43,6 @@
 from tensorflow.python.training.tracking.tracking import delete_tracking
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
-from tensorflow.python.util import object_identity
 
 # To avoid circular dependencies between keras/engine and keras/saving,
 # code in keras/saving must delay imports.
@@ -179,8 +178,6 @@
     # records all nodes that were generated directly/indirectly from the config,
     # so that they do not get recreated multiple times.
     self._nodes_recreated_from_config = {}
-    self._all_nodes_recreated_from_config = (
-        object_identity.ObjectIdentityWeakSet())
     # Store all node ids that have already been traversed when tracking nodes
     # that were recreated from the config.
     self._traversed_nodes_from_config = []
@@ -293,7 +290,6 @@
                      'Object: {}'.format(obj_child))
       self._nodes_recreated_from_config[child_id] = (
           obj_child, self._config_node_setter(setter))
-      self._all_nodes_recreated_from_config.add(obj_child)
       self._add_children_recreated_from_config(
           obj_child, child_proto, child_id)
 
@@ -363,7 +359,6 @@
 
     setter = self._config_node_setter(_revive_setter)
     self._nodes_recreated_from_config[node_id] = obj, setter
-    self._all_nodes_recreated_from_config.add(obj)
     self._add_children_recreated_from_config(
         obj, self._proto.nodes[node_id], node_id)
     return obj, setter
@@ -380,8 +375,11 @@
         metadata['class_name'] == 'Sequential' or
         metadata['class_name'] == 'Functional')
     if not (generic_utils.validate_config(config) and
-            model_is_functional_or_sequential):
-      return None  # Revive as custom model.
+            model_is_functional_or_sequential
+           ) or generic_utils.get_registered_object(class_name) is not None:
+      # Model should not be revived as a graph network. Try reviving directly
+      # from config or as a custom model.
+      return None
 
     # Revive functional and sequential models as blank model objects for now (
     # must be initialized to enable setattr tracking and attribute caching).
diff --git a/tensorflow/python/keras/saving/saved_model/revive_test.py b/tensorflow/python/keras/saving/saved_model/revive_test.py
index 5e94597..5c4f8a2 100644
--- a/tensorflow/python/keras/saving/saved_model/revive_test.py
+++ b/tensorflow/python/keras/saving/saved_model/revive_test.py
@@ -24,9 +24,9 @@
 from __future__ import division
 from __future__ import print_function
 
-import os
 import shutil
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import keras
@@ -115,18 +115,43 @@
             'name': self.name}
 
 
-class TestModelRevive(keras_parameterized.TestCase):
+class CustomNetworkDefaultConfig(keras.Model):
+
+  def __init__(self, num_classes, name=None):
+    inputs = keras.Input((2, 3), name='inputs')
+    x = keras.layers.Flatten(name='flatten')(inputs)
+    y = keras.layers.Dense(num_classes, name='outputs')(x)
+    super(CustomNetworkDefaultConfig, self).__init__(inputs, y, name=name)
+
+
+class CustomNetworkWithConfig(CustomNetworkDefaultConfig):
+
+  def __init__(self, num_classes, name=None):
+    super(CustomNetworkWithConfig, self).__init__(num_classes, name=name)
+    self._config_dict = dict(num_classes=num_classes)
+
+  def get_config(self):
+    return self._config_dict
+
+  @classmethod
+  def from_config(cls, config):
+    return cls(config['num_classes'], name=config.get('name'))
+
+
+class CustomNetworkWithConfigName(CustomNetworkWithConfig):
+
+  def __init__(self, num_classes, name=None):
+    super(CustomNetworkWithConfigName, self).__init__(num_classes, name=name)
+    self._config_dict['name'] = self.name
+
+
+class ReviveTestBase(keras_parameterized.TestCase):
 
   def setUp(self):
-    super(TestModelRevive, self).setUp()
+    super(ReviveTestBase, self).setUp()
     self.path = self.get_temp_dir()
     self.addCleanup(shutil.rmtree, self.path, ignore_errors=True)
 
-  def _save_model_dir(self, dirname='saved_model'):
-    temp_dir = self.get_temp_dir()
-    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
-    return os.path.join(temp_dir, dirname)
-
   def _assert_revived_correctness(self, model, revived):
     self.assertAllEqual(model.input_names, revived.input_names)
     self.assertAllEqual(model.output_names, revived.output_names)
@@ -173,6 +198,11 @@
         self.assertEqual(type(model_layer).__name__,
                          type(revived_layer).__name__)
 
+
+# These tests take a while to run, so each should run in a separate shard
+# (putting them in a separate TestCase helps achieve this).
+class TestBigModelRevive(ReviveTestBase):
+
   @keras_parameterized.run_with_all_model_types
   def test_revive(self):
     input_shape = None
@@ -235,6 +265,9 @@
     revived = keras_load.load(self.path)
     self._assert_revived_correctness(model, revived)
 
+
+class TestModelRevive(ReviveTestBase):
+
   def test_revive_subclassed_with_nested_model(self):
     model = SubclassedModelNoConfig(1., 2.)
     # Run data through the Model to create save spec and weights.
@@ -244,17 +277,31 @@
     self._assert_revived_correctness(model, revived)
 
   def test_revive_sequential_inputs(self):
-    model = keras.models.Sequential(
-        [keras.Input((None,), dtype=dtypes.string),
-         keras.layers.Lambda(string_ops.string_lower)])
+    model = keras.models.Sequential([
+        keras.Input((None,), dtype=dtypes.string),
+        keras.layers.Lambda(string_ops.string_lower)
+    ])
     model.save(self.path, save_format='tf')
     revived = keras_load.load(self.path)
     self.assertEqual(dtypes.string, revived._layers[0].dtype)
 
+  @parameterized.named_parameters(
+      ('default_config', CustomNetworkDefaultConfig),
+      ('with_config', CustomNetworkWithConfig),
+      ('with_config_name', CustomNetworkWithConfigName))
+  def test_revive_network(self, model_cls):
+    model = model_cls(8)
+    model.save(self.path, include_optimizer=False, save_format='tf')
+    revived = keras_load.load(self.path, compile=False)
+    self._assert_revived_correctness(model, revived)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
   with generic_utils.CustomObjectScope({
       'CustomLayerWithConfig': CustomLayerWithConfig,
-      'SubclassedModelWithConfig': SubclassedModelWithConfig}):
+      'CustomNetworkWithConfig': CustomNetworkWithConfig,
+      'CustomNetworkWithConfigName': CustomNetworkWithConfigName,
+      'SubclassedModelWithConfig': SubclassedModelWithConfig
+  }):
     test.main()
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index 1dff9a2..69115e0 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -507,7 +507,7 @@
 
     self.assertAllClose(
         model.predict(input_arr),
-        loaded.signatures['predict'](ops.convert_to_tensor_v2(
+        loaded.signatures['predict'](ops.convert_to_tensor_v2_with_dispatch(
             input_arr.astype('float32')))['predictions'])
 
     feature = {
@@ -517,7 +517,7 @@
     example = example_pb2.Example(
         features=feature_pb2.Features(feature=feature))
     outputs = loaded.signatures['parse_and_predict'](
-        ops.convert_to_tensor_v2([example.SerializeToString()]))
+        ops.convert_to_tensor_v2_with_dispatch([example.SerializeToString()]))
     self.assertAllClose(model.predict(input_arr), outputs['predictions'])
     self.assertAllClose(model.layers[0](input_arr), outputs['layer_1_outputs'])
 
diff --git a/tensorflow/python/keras/saving/saved_model_experimental.py b/tensorflow/python/keras/saving/saved_model_experimental.py
index 25628cd..6476518 100644
--- a/tensorflow/python/keras/saving/saved_model_experimental.py
+++ b/tensorflow/python/keras/saving/saved_model_experimental.py
@@ -30,8 +30,8 @@
 from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.utils import mode_keys
 from tensorflow.python.keras.utils.generic_utils import LazyLoader
-from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import builder as saved_model_builder
 from tensorflow.python.saved_model import constants
@@ -152,7 +152,8 @@
   model_json_filepath = os.path.join(
       saved_model_utils.get_or_create_assets_dir(saved_model_path),
       compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
-  file_io.write_string_to_file(model_json_filepath, model_json)
+  with gfile.Open(model_json_filepath, 'w') as f:
+    f.write(model_json)
 
 
 def _export_model_variables(model, saved_model_path):
@@ -417,7 +418,8 @@
       compat.as_bytes(saved_model_path),
       compat.as_bytes(constants.ASSETS_DIRECTORY),
       compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
-  model_json = file_io.read_file_to_string(model_json_filepath)
+  with gfile.Open(model_json_filepath, 'r') as f:
+    model_json = f.read()
   model = model_config.model_from_json(
       model_json, custom_objects=custom_objects)
 
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 550ff66..96868d0 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -26,6 +26,7 @@
 
 from tensorflow.python import tf2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -937,3 +938,69 @@
   """Uses gpu when requested and available."""
   with device(should_use_gpu=True):
     yield
+
+
+def for_all_test_methods(decorator, *args, **kwargs):
+  """Generate class-level decorator from given method-level decorator.
+
+  It is expected for the given decorator to take some arguments and return
+  a method that is then called on the test method to produce a decorated
+  method.
+
+  Args:
+    decorator: The decorator to apply.
+    *args: Positional arguments
+    **kwargs: Keyword arguments
+  Returns: Function that will decorate a given classes test methods with the
+    decorator.
+  """
+
+  def all_test_methods_impl(cls):
+    """Apply decorator to all test methods in class."""
+    for name in dir(cls):
+      value = getattr(cls, name)
+      if (callable(value) and name.startswith('test') and
+          name != 'test_session'):
+        setattr(cls, name, decorator(*args, **kwargs)(value))
+    return cls
+
+  return all_test_methods_impl
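+
+
+# Illustrative usage of `for_all_test_methods` (the test class below is
+# hypothetical):
+#
+#   @for_all_test_methods(unittest.skipIf, not tf2.enabled(), 'TF2-only test')
+#   class MyTF2OnlyTest(test.TestCase):
+#     ...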
+
+
+# The description is just for documentation purposes.
+def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
+  """Execute test with TensorFloat-32 disabled.
+
+  While almost every real-world deep learning model runs fine with
+  TensorFloat-32, many tests use assertAllClose or similar methods, and
+  TensorFloat-32 matmuls typically cause such methods to fail with the
+  default tolerances.
+
+  Args:
+    description: A description used for documentation purposes, describing why
+      the test requires TensorFloat-32 to be disabled.
+
+  Returns:
+    Decorator which runs a test with TensorFloat-32 disabled.
+  """
+
+  def decorator(f):
+
+    @functools.wraps(f)
+    def decorated(self, *args, **kwargs):
+      allowed = config.tensor_float_32_execution_enabled()
+      try:
+        config.enable_tensor_float_32_execution(False)
+        f(self, *args, **kwargs)
+      finally:
+        config.enable_tensor_float_32_execution(allowed)
+
+    return decorated
+
+  return decorator
+
+
+# The description is just for documentation purposes.
+def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
+  """Execute all tests in a class with TensorFloat-32 disabled."""
+  return for_all_test_methods(run_without_tensor_float_32, description)
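+
+
+# Illustrative usage of the TensorFloat-32 decorators above (the test class
+# below is hypothetical):
+#
+#   @run_all_without_tensor_float_32('Asserts close on matmul outputs')
+#   class MyDenseLayerTest(test.TestCase):
+#
+#     def test_dense_output(self):
+#       ...  # Runs with TensorFloat-32 disabled, then restores the setting.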
diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD
index 4db3327..9ce20c2 100644
--- a/tensorflow/python/keras/tests/BUILD
+++ b/tensorflow/python/keras/tests/BUILD
@@ -274,7 +274,6 @@
     name = "op_callbacks_test",
     srcs = ["op_callbacks_test.py"],
     python_version = "PY3",
-    tags = ["no_cuda11"],
     xla_enable_strict_auto_jit = False,
     deps = [
         "//tensorflow/python:framework_ops",
diff --git a/tensorflow/python/keras/tests/model_subclassing_test.py b/tensorflow/python/keras/tests/model_subclassing_test.py
index 4b9b625..f47ee62 100644
--- a/tensorflow/python/keras/tests/model_subclassing_test.py
+++ b/tensorflow/python/keras/tests/model_subclassing_test.py
@@ -428,7 +428,7 @@
       def call(self, inputs):
         return inputs + self.b + self.c
 
-    x = ops.convert_to_tensor_v2(np.ones((10, 10), 'float32'))
+    x = ops.convert_to_tensor_v2_with_dispatch(np.ones((10, 10), 'float32'))
     model = MyModel()
     model(x)
     self.assertEqual(1, len(model.trainable_weights))
@@ -444,7 +444,7 @@
       def call(self, inputs):
         return inputs + self.b + self.c
 
-    x = ops.convert_to_tensor_v2(np.ones((10, 10), 'float32'))
+    x = ops.convert_to_tensor_v2_with_dispatch(np.ones((10, 10), 'float32'))
     model = MyModelCustomBuild()
     model(x)
     self.assertEqual(1, len(model.trainable_weights))
@@ -467,7 +467,7 @@
         self.add_update(self.c.assign(inputs[1, :]))
         return inputs + self.b + self.c
 
-    x = ops.convert_to_tensor_v2(np.ones((10, 10), 'float32'))
+    x = ops.convert_to_tensor_v2_with_dispatch(np.ones((10, 10), 'float32'))
     model = MyModel()
     model(x)
 
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 3195bb0..7959b02 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -29,7 +29,6 @@
 from tensorflow.python.keras.utils.conv_utils import convert_kernel
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
-from tensorflow.python.util import object_identity
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -104,7 +103,7 @@
   Returns:
       The total number of scalars composing the weights
   """
-  unique_weights = object_identity.ObjectIdentitySet(weights)
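+  # Dedupe weights by object identity, since Variables are not hashable.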
+  unique_weights = {id(w): w for w in weights}.values()
   weight_shapes = [w.shape.as_list() for w in unique_weights]
   standardized_weight_shapes = [
       [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
@@ -502,4 +501,3 @@
 
   wrapped.cache = cache
   return wrapped
-
diff --git a/tensorflow/python/keras/utils/losses_utils.py b/tensorflow/python/keras/utils/losses_utils.py
index b8a063e..08ef613 100644
--- a/tensorflow/python/keras/utils/losses_utils.py
+++ b/tensorflow/python/keras/utils/losses_utils.py
@@ -253,11 +253,11 @@
     ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access
 
     if not isinstance(losses, keras_tensor.KerasTensor):
-      losses = ops.convert_to_tensor_v2(losses)
+      losses = ops.convert_to_tensor_v2_with_dispatch(losses)
     input_dtype = losses.dtype
 
     if not isinstance(sample_weight, keras_tensor.KerasTensor):
-      sample_weight = ops.convert_to_tensor_v2(sample_weight)
+      sample_weight = ops.convert_to_tensor_v2_with_dispatch(sample_weight)
 
     # TODO(psv): Handle casting here in a better way, eg. if losses is float64
     # we do not want to lose precision.
diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py
index 7d47850..5b3905a 100644
--- a/tensorflow/python/keras/utils/metrics_utils.py
+++ b/tensorflow/python/keras/utils/metrics_utils.py
@@ -311,7 +311,8 @@
 
   y_true = math_ops.cast(y_true, dtype=variable_dtype)
   y_pred = math_ops.cast(y_pred, dtype=variable_dtype)
-  thresholds = ops.convert_to_tensor_v2(thresholds, dtype=variable_dtype)
+  thresholds = ops.convert_to_tensor_v2_with_dispatch(
+      thresholds, dtype=variable_dtype)
   num_thresholds = thresholds.shape[0]
   if multi_label:
     one_thresh = math_ops.equal(
diff --git a/tensorflow/python/keras/utils/tf_utils_test.py b/tensorflow/python/keras/utils/tf_utils_test.py
index 9a3939e..73d8671 100644
--- a/tensorflow/python/keras/utils/tf_utils_test.py
+++ b/tensorflow/python/keras/utils/tf_utils_test.py
@@ -44,14 +44,17 @@
       self.assertFalse(tf_utils.is_symbolic_tensor(
           variables.Variable(name='blah', initial_value=0.)))
       self.assertFalse(
-          tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
+          tf_utils.is_symbolic_tensor(
+              ops.convert_to_tensor_v2_with_dispatch(0.)))
       self.assertFalse(tf_utils.is_symbolic_tensor(
           sparse_tensor.SparseTensor(
               indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
     else:
       self.assertTrue(tf_utils.is_symbolic_tensor(
           variables.Variable(name='blah', initial_value=0.)))
-      self.assertTrue(tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
+      self.assertTrue(
+          tf_utils.is_symbolic_tensor(
+              ops.convert_to_tensor_v2_with_dispatch(0.)))
       self.assertTrue(tf_utils.is_symbolic_tensor(
           sparse_tensor.SparseTensor(
               indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
@@ -61,7 +64,7 @@
     class CustomClass(object):
 
       def value(self):
-        return ops.convert_to_tensor_v2(42.)
+        return ops.convert_to_tensor_v2_with_dispatch(42.)
 
     ops.register_tensor_conversion_function(
         CustomClass, lambda value, **_: value.value())
@@ -72,7 +75,8 @@
       self.assertFalse(tf_utils.is_symbolic_tensor(
           variables.Variable(name='blah', initial_value=0.)))
       self.assertFalse(
-          tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
+          tf_utils.is_symbolic_tensor(
+              ops.convert_to_tensor_v2_with_dispatch(0.)))
       self.assertFalse(tf_utils.is_symbolic_tensor(
           sparse_tensor.SparseTensor(
               indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
@@ -80,7 +84,9 @@
     else:
       self.assertTrue(tf_utils.is_symbolic_tensor(
           variables.Variable(name='blah', initial_value=0.)))
-      self.assertTrue(tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
+      self.assertTrue(
+          tf_utils.is_symbolic_tensor(
+              ops.convert_to_tensor_v2_with_dispatch(0.)))
       self.assertTrue(tf_utils.is_symbolic_tensor(
           sparse_tensor.SparseTensor(
               indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
@@ -95,7 +101,7 @@
 
       def __init__(self, input_):
         self._input = input_
-        self.value = ops.convert_to_tensor_v2([[42.]])
+        self.value = ops.convert_to_tensor_v2_with_dispatch([[42.]])
 
       @property
       def dtype(self):
@@ -110,7 +116,7 @@
       def __init__(self, fn, **kwargs):
         def _fn(*fargs, **fkwargs):
           d = fn(*fargs, **fkwargs)
-          x = ops.convert_to_tensor_v2(d)
+          x = ops.convert_to_tensor_v2_with_dispatch(d)
           d.shape = x.shape
           d.get_shape = x.get_shape
           return d, x
@@ -138,7 +144,7 @@
     model = keras.Model(model.inputs, model(model.outputs))
     # Now we instantiate the model and verify we have a `Foo` object, not a
     # `Tensor`.
-    y = model(ops.convert_to_tensor_v2([[7.]]))
+    y = model(ops.convert_to_tensor_v2_with_dispatch([[7.]]))
     self.assertIsInstance(y, Foo)
     # Confirm that (custom) loss sees `Foo` instance, not Tensor.
     obtained_prediction_box = [None]
diff --git a/tensorflow/python/keras/utils/vis_utils_test.py b/tensorflow/python/keras/utils/vis_utils_test.py
index 8b401c3..ccdde30 100644
--- a/tensorflow/python/keras/utils/vis_utils_test.py
+++ b/tensorflow/python/keras/utils/vis_utils_test.py
@@ -38,8 +38,8 @@
     try:
       vis_utils.plot_model(
           model, to_file=dot_img_file, show_shapes=True, show_dtype=True)
-      self.assertTrue(file_io.file_exists(dot_img_file))
-      file_io.delete_file(dot_img_file)
+      self.assertTrue(file_io.file_exists_v2(dot_img_file))
+      file_io.delete_file_v2(dot_img_file)
     except ImportError:
       pass
 
@@ -68,8 +68,8 @@
           show_shapes=True,
           show_dtype=True,
           expand_nested=True)
-      self.assertTrue(file_io.file_exists(dot_img_file))
-      file_io.delete_file(dot_img_file)
+      self.assertTrue(file_io.file_exists_v2(dot_img_file))
+      file_io.delete_file_v2(dot_img_file)
     except ImportError:
       pass
 
@@ -86,8 +86,8 @@
           show_shapes=True,
           show_dtype=True,
           expand_nested=True)
-      self.assertTrue(file_io.file_exists(dot_img_file))
-      file_io.delete_file(dot_img_file)
+      self.assertTrue(file_io.file_exists_v2(dot_img_file))
+      file_io.delete_file_v2(dot_img_file)
     except ImportError:
       pass
 
@@ -102,8 +102,8 @@
           show_shapes=True,
           show_dtype=True,
           expand_nested=True)
-      self.assertTrue(file_io.file_exists(dot_img_file))
-      file_io.delete_file(dot_img_file)
+      self.assertTrue(file_io.file_exists_v2(dot_img_file))
+      file_io.delete_file_v2(dot_img_file)
     except ImportError:
       pass
 
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 0d6b6ac..2df4860 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2048,6 +2048,9 @@
     name = "dynamic_partition_op_test",
     size = "medium",
     srcs = ["dynamic_partition_op_test.py"],
+    tags = [
+        "multi_and_single_gpu",
+    ],
     tfrt_enabled = True,
     deps = [
         "//tensorflow/python:array_ops",
diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
index 30b6102..ac82a32 100644
--- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
@@ -130,6 +130,7 @@
 
 def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):
 
+  @test_util.run_without_tensor_float_32("Tests batch matmul")
   def Test(self):
     np.random.seed(42)
     self._testNonEmpty(dtype, adjoint_a, adjoint_b, use_static_shape)
@@ -141,6 +142,7 @@
 def _GetBatchMatmulOpBroadcastingTest(dtype, adjoint_a, adjoint_b,
                                       use_static_shape):
 
+  @test_util.run_without_tensor_float_32("Tests batch matmul")
   def Test(self):
     np.random.seed(42)
     self._testBroadcasting(dtype, adjoint_a, adjoint_b, use_static_shape)
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index a9afca8..0697f7d 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -106,7 +106,7 @@
   def _verifyCholesky(self, x):
     # Verify that LL^T == x.
     chol = linalg_ops.cholesky(x)
-    verification = math_ops.matmul(chol, chol, adjoint_b=True)
+    verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True)
     self._verifyCholeskyBase(x, chol, verification)
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
@@ -271,8 +271,8 @@
     def Compute(x):
       # Turn the random matrix x into a Hermitian matrix by
       # computing the quadratic form x * x^H.
-      a = math_ops.matmul(x, math_ops.conj(
-          array_ops.matrix_transpose(x))) / shape[0]
+      a = test_util.matmul_without_tf32(
+          x, math_ops.conj(array_ops.matrix_transpose(x))) / shape[0]
       if batch:
         a = array_ops.tile(array_ops.expand_dims(a, 0), [2, 1, 1])
       # Finally take the cholesky decomposition of the Hermitian matrix.
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 4c9fbd5..30b20c6 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -20,7 +20,6 @@
 from __future__ import print_function
 
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.compat.compat import forward_compatibility_horizon
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -1607,5 +1606,4 @@
 
 
 if __name__ == "__main__":
-  with forward_compatibility_horizon(2020, 8, 21):
-    test.main()
+  test.main()
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index ff4da3a..3acc1fe 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -48,6 +48,10 @@
   return test_configs
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests Conv3d, which in some cases is implemented with a matmul. With "
+    "tf32, tests fail in some of those cases (and as of August 13 2020, only "
+    "those cases)")
 class Conv3DTest(test.TestCase):
 
   def _DtypesToTest(self, use_gpu):
@@ -189,9 +193,9 @@
                 e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-6)
 
   def _CreateNumpyTensor(self, sizes):
-    return np.asarray([f * 1.0
-                       for f in range(1,
-                                      np.prod(sizes) + 1)]).reshape(sizes)
+    return np.asarray(
+        [f * 1.0 for f in range(1, np.prod(sizes) + 1)],
+        dtype=np.float32).reshape(sizes)
 
   @test_util.run_in_graph_and_eager_modes
   def testConv3DExpandedBatch(self):
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index f480f43..f723439 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -310,14 +310,14 @@
           data_format, use_gpu)
       expected_results.append(expected)
       computed_results.append(computed)
-      tolerance = 1e-2 if use_gpu else 1e-5
-      expected_values = self.evaluate(expected_results)
-      computed_values = self.evaluate(computed_results)
-      for e_value, c_value in zip(expected_values, computed_values):
-        tf_logging.debug("expected = %s", e_value)
-        tf_logging.debug("actual = %s", c_value)
-        self.assertAllClose(
-            e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=rtol)
+    tolerance = 1e-2 if use_gpu else 1e-5
+    expected_values = self.evaluate(expected_results)
+    computed_values = self.evaluate(computed_results)
+    for e_value, c_value in zip(expected_values, computed_values):
+      tf_logging.debug("expected = %s", e_value)
+      tf_logging.debug("actual = %s", c_value)
+      self.assertAllClose(
+          e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=rtol)
 
   def _VerifyValues(self,
                     tensor_in_sizes,
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index a7d8f84..96628a1 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -1032,12 +1032,10 @@
     self._compare_values(x, y=y)
 
   def testTypes(self):
-    # TODO(b/131162241): Enable test for GPU
-    with ops.device("/CPU:0"):
-      for dtype in [np.float16, np.float32, np.float64,
-                    dtypes_lib.bfloat16.as_numpy_dtype]:
-        with self.subTest(dtype=dtype):
-          self._testDtype(dtype)
+    for dtype in [np.float16, np.float32, np.float64,
+                  dtypes_lib.bfloat16.as_numpy_dtype]:
+      with self.subTest(dtype=dtype):
+        self._testDtype(dtype)
 
 
 class ComplexMakeRealImagTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
index a4c07da..7c8f389 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
@@ -268,6 +268,8 @@
       self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.)
       self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
 
+  @test_util.run_without_tensor_float_32(
+      "Tests DirichletMultinomial.covariance, which calls matmul")
   def testCovariance(self):
     # Shape [2]
     alpha = [1., 2]
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index 0f96382..a0d8bef 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -200,6 +200,8 @@
     self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.)
     self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
 
+  @test_util.run_without_tensor_float_32(
+      "Calls Dirichlet.covariance, which calls matmul")
   def testVariance(self):
     alpha = [1., 2, 3]
     denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 8c44819..0fd9790 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -23,8 +23,10 @@
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import data_flow_ops
@@ -346,6 +348,19 @@
       res = self.evaluate(partitioned)
     self.assertEqual(res[-1].shape[0], 192)
 
+  # See https://github.com/tensorflow/tensorflow/issues/42500.
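+  # Runs the same dynamic_partition on every available GPU and verifies that
+  # each device produces the expected result.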
+  def testMultiGPU(self):
+    device_list = config.list_logical_devices("GPU")
+    results = []
+    for device in device_list:
+      with ops.device(device.name):
+        data = constant_op.constant(np.zeros((1000,)))
+        partitions = constant_op.constant(np.arange(1000, dtype=np.int32) % 10)
+        result = data_flow_ops.dynamic_partition(data, partitions, 10)
+        results.append(self.evaluate(result))
+    if device_list:
+      self.assertAllEqual(results, np.zeros((len(device_list), 10, 100)))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py
index 10b9671..4236eb9 100644
--- a/tensorflow/python/kernel_tests/einsum_op_test.py
+++ b/tensorflow/python/kernel_tests/einsum_op_test.py
@@ -35,6 +35,8 @@
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_without_tensor_float_32(
+    'Tests einsum, which sometimes does a matmul with cuBLAS')
 class EinsumOpTest(test.TestCase):
 
   def _check(self, s, *input_shapes, **kwargs):
@@ -237,7 +239,6 @@
           ((4, 3), (None, 3)))
     check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
 
-  @test_util.disable_xla('b/131919749')
   def testOutputRepeatedLabels(self):
     # This is the reverse operation of generalized traces, to be used for
     # computing symbolic gradients of einsum. Note: this operation is not
@@ -264,7 +265,6 @@
     # From transformer xl.
     check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))
 
-  @test_util.disable_xla('b/131919749')
   def testEmptyWithRepeatedLabels(self):
 
     def check(equation, input_shapes, output_shape):
@@ -287,6 +287,8 @@
 
 
 @test_util.run_all_in_graph_and_eager_modes
+@test_util.run_all_without_tensor_float_32(
+    "Tests einsum's gradient, which sometimes does a matmul with cuBLAS")
 class EinsumGradTest(test.TestCase):
 
   def _check_gradient(self, s, *input_shapes):
@@ -310,7 +312,6 @@
           self.assertLess(
               gradient_checker_v2.max_error(analytical, numerical), tol)
 
-  @test_util.disable_xla('b/131919749')
   def testUnary(self):
     # Unary cases.
     self._check_gradient('->', ())
@@ -319,7 +320,6 @@
     self._check_gradient('aabcd->add', (3, 3, 5, 4, 4))
     self._check_gradient('abcd->da', (3, 5, 4, 2))
 
-  @test_util.disable_xla('b/131919749')
   def testUnaryEllipsis(self):
     self._check_gradient('...->...', ())
     self._check_gradient('...->', ())
@@ -362,11 +362,9 @@
     self._check_gradient('ijkm,ijln->ijmn', (2, 3, 3, 4), (2, 3, 3, 2))
     self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))
 
-  @test_util.disable_xla('b/131919749')
   def testReducedIndicesWithRepeatedLabels(self):
     self._check_gradient('abce,badf->bcba', (1, 2, 3, 4), (2, 1, 4, 3))
 
-  @test_util.disable_xla('b/131919749')
   def testRepeatedLabels(self):
     # Repeated indices.
     self._check_gradient('aba,a->b', (3, 4, 3), (3,))
@@ -376,7 +374,6 @@
     self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
     self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
 
-  @test_util.disable_xla('b/131919749')
   def testEmptyWithRepeatedLabels(self):
     self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
     self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
@@ -388,7 +385,6 @@
     self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
     self._check_gradient('i...j,j...k->i...k', (3, 1, 2, 2), (2, 2, 3, 1, 4))
 
-  @test_util.disable_xla('b/131919749')
   def testBroadcastingWithRepeatedLabels(self):
     self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
     self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index e3268fa..f2348c6 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -945,6 +945,8 @@
       self.assertAllClose(abs_value, count, rtol=tol, atol=tol)
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests convolutional_orthogonal_1d, which calls matmul")
 class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
 
   @test_util.run_deprecated_v1
@@ -1174,6 +1176,8 @@
         self.assertAllClose(self.evaluate(ratio), gain, rtol=tol, atol=tol)
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests convolutional_orthogonal_3d, which calls matmul")
 class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
 
   @test_util.run_deprecated_v1
diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py
index ac82f19..f42600b 100644
--- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py
@@ -534,7 +534,7 @@
       c_value = self.evaluate(c)
 
       expected_c_value = self.evaluate(
-          math_ops.conj(math_ops.matmul(a_dense, b)))
+          math_ops.conj(test_util.matmul_without_tf32(a_dense, b)))
       self.assertAllClose(expected_c_value, c_value)
 
   @test_util.run_in_graph_and_eager_modes
@@ -576,7 +576,7 @@
                 transpose_b=transpose_b,
                 adjoint_a=adjoint_a,
                 adjoint_b=adjoint_b)
-            c_dense_t = math_ops.matmul(
+            c_dense_t = test_util.matmul_without_tf32(
                 a_mats,
                 b_mats,
                 transpose_a=transpose_a,
@@ -640,7 +640,7 @@
                 adjoint_b=adjoint_b)
 
             # Example: t(adj(a) . b) = t(b) . conj(a)
-            c_dense_t = math_ops.matmul(
+            c_dense_t = test_util.matmul_without_tf32(
                 math_ops.conj(b_mats) if adjoint_b else b_mats,
                 math_ops.conj(a_mats) if adjoint_a else a_mats,
                 transpose_a=not (transpose_b or adjoint_b),
@@ -670,7 +670,7 @@
     c_t = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
         a_sm, b_mats, conjugate_output=True)
 
-    c_dense_t = math_ops.conj(math_ops.matmul(a_mats, b_mats))
+    c_dense_t = math_ops.conj(test_util.matmul_without_tf32(a_mats, b_mats))
     self.assertAllEqual(c_t.shape, c_dense_t.shape)
     c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t))
 
@@ -772,7 +772,7 @@
             adjoint_b=adjoint_b)
         c_sm_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
             c_sm, dtypes.float32)
-        c_dense_t = math_ops.matmul(
+        c_dense_t = test_util.matmul_without_tf32(
             a_mats,
             b_mats,
             transpose_a=transpose_a,
@@ -1143,7 +1143,7 @@
         dense_cholesky = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
             cholesky_sparse_matrices, dtype)
         # Compute L * Lh where L is the Sparse Cholesky factor.
-        verification = math_ops.matmul(
+        verification = test_util.matmul_without_tf32(
             dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True))
         verification = twist_matrix(verification, ordering_amd)
         # Assert that input matrix A satisfies A = L * Lh.
@@ -1197,7 +1197,7 @@
           cholesky_sparse_matrix, dtype)
 
       # Compute L * Lh.
-      verification = math_ops.matmul(
+      verification = test_util.matmul_without_tf32(
           dense_cholesky,
           array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True))
       verification = twist_matrix(verification, ordering_amd)
@@ -1238,7 +1238,7 @@
         cholesky_sparse_matrix, dtypes.float32)
 
     # Compute L * Lh.
-    verification = math_ops.matmul(
+    verification = test_util.matmul_without_tf32(
         dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1]))
     verification = twist_matrix(verification, ordering_amd)
     verification_values = self.evaluate(verification)
diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py
index 35c706c..4aa3474 100644
--- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py
+++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_test.py
@@ -162,7 +162,7 @@
                          1.j * np.random.randn(*dense_shape_b))).astype(dtype)
       a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
       b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats)
-      c_dense = math_ops.matmul(
+      c_dense = test_util.matmul_without_tf32(
           a_mats,
           b_mats,
           transpose_a=transpose_a,
@@ -202,7 +202,7 @@
       b_mats = (np.random.randn(*dense_shape_b) +
                 1.j * np.random.randn(*dense_shape_b)).astype(dtype)
       a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
-      c_dense = math_ops.matmul(
+      c_dense = test_util.matmul_without_tf32(
           a_mats,
           b_mats,
           transpose_a=transpose_a,
@@ -240,7 +240,7 @@
       b_mats = sparsify((np.random.randn(*dense_shape_b) +
                          1.j * np.random.randn(*dense_shape_b))).astype(dtype)
       b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats)
-      c_dense = math_ops.matmul(
+      c_dense = test_util.matmul_without_tf32(
           a_mats,
           b_mats,
           transpose_a=transpose_a,
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index f1d885f..273aba4 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -63,6 +63,9 @@
 
   @test_util.enable_control_flow_v2
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32(
+      'Tests `tf.linalg.expm`, which calls matmul. Additionally, calls ops '
+      'which do matmul in their gradient, such as MatrixSolve.')
   def Test(self):
 
     def RandomInput():
@@ -102,6 +105,16 @@
                                         **kwargs_):
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32(
+      'Tests `tf.linalg.lstsq`, which calls matmul. Additionally, calls ops '
+      'which do matmul in their gradient, such as MatrixSolveLs.')
+  # TODO(b/164254522): With tf32, some tests fail with extremely high absolute
+  # and relative differences when calling assertAllClose. For example, the test
+  # test_MatrixSolveLsGradient_float32_10_10_1e-06 of class
+  # MatrixBinaryFunctorGradientTest fails with a max absolute difference of
+  # 0.883 and a max relative difference of 736892. We should consider disabling
+  # tf32 within `tf.linalg.lstsq` and perhaps other linear algebra functions,
+  # even if tf32 is allowed globally.
   def Test(self):
 
     def RandomInput():
diff --git a/tensorflow/python/kernel_tests/lu_op_test.py b/tensorflow/python/kernel_tests/lu_op_test.py
index fee6aec..8d522e8 100644
--- a/tensorflow/python/kernel_tests/lu_op_test.py
+++ b/tensorflow/python/kernel_tests/lu_op_test.py
@@ -91,7 +91,7 @@
     # Prepare the upper factor.
     upper = array_ops.matrix_band_part(lu, 0, -1)
 
-    verification = math_ops.matmul(lower, upper)
+    verification = test_util.matmul_without_tf32(lower, upper)
 
     # Permute the rows of product of the Cholesky factors.
     if num_rows > 0:
diff --git a/tensorflow/python/kernel_tests/map_ops_test.py b/tensorflow/python/kernel_tests/map_ops_test.py
index 8db5cd7..771e22e 100644
--- a/tensorflow/python/kernel_tests/map_ops_test.py
+++ b/tensorflow/python/kernel_tests/map_ops_test.py
@@ -26,6 +26,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import map_ops
+from tensorflow.python.ops import sort_ops
 from tensorflow.python.platform import test
 
 
@@ -57,7 +58,7 @@
     m = map_ops.empty_tensor_map()
     k = constant_op.constant(1.0)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to lookup non-existent key."):
+                                "Trying to lookup non-existent key. *"):
       l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
       self.evaluate(l)
 
@@ -68,7 +69,7 @@
     v = constant_op.constant(11.0)
     m = map_ops.tensor_map_insert(m, k, v)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to lookup non-existent key."):
+                                "Trying to lookup non-existent key. *"):
       l = map_ops.tensor_map_lookup(m, k2, dtypes.float32)
       self.evaluate(l)
 
@@ -87,7 +88,7 @@
     m = map_ops.empty_tensor_map()
     k = constant_op.constant(1.0)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to erase non-existent item."):
+                                "Trying to erase non-existent item. *"):
       m = map_ops.tensor_map_erase(m, k, dtypes.float32)
       self.evaluate(m)
 
@@ -98,7 +99,7 @@
     v = constant_op.constant(2.0)
     m = map_ops.tensor_map_insert(m, k2, v)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to erase non-existent item."):
+                                "Trying to erase non-existent item. *"):
       m = map_ops.tensor_map_erase(m, k, dtypes.float32)
       self.evaluate(m)
 
@@ -133,6 +134,58 @@
     self.assertAllClose(l, v)
     self.assertAllClose(l2, default_value)
 
+  def testStackKeys(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant(1.0)
+    k2 = constant_op.constant(2.0)
+    k3 = constant_op.constant(3.0)
+    v = constant_op.constant(21.0)
+    v2 = constant_op.constant(22.0)
+    v3 = constant_op.constant(23.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    m = map_ops.tensor_map_insert(m, k2, v2)
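+    # Stacked key order is not guaranteed, so sort the keys before comparing
+    # against the expected values.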
+    keys = map_ops.tensor_map_stack_keys(m, k.dtype)
+    expected = constant_op.constant([1.0, 2.0])
+    self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
+    self.assertAllClose(sort_ops.sort(keys), expected)
+
+    m = map_ops.tensor_map_insert(m, k3, v3)
+    keys = map_ops.tensor_map_stack_keys(m, k.dtype)
+    expected = constant_op.constant([1.0, 2.0, 3.0])
+    self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
+    self.assertAllClose(sort_ops.sort(keys), expected)
+
+  def testStackKeysEmptyMapFails(self):
+    m = map_ops.empty_tensor_map()
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError, "TensorMapStackKeys cannot be called "
+        "on empty map."):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
+  def testStackKeysIncorrectDtypeFails(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant("key_with_wrong_dtype")
+    v = constant_op.constant(2.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    simple = "Key does not match requested dtype."
+    with self.assertRaisesRegex(errors.InvalidArgumentError, simple):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
+  def testStackKeysIncorrectShapeFails(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant(1.0)
+    k2 = constant_op.constant([1.0, 11.0])
+    v = constant_op.constant(2.0)
+    v2 = constant_op.constant(22.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    m = map_ops.tensor_map_insert(m, k2, v2)
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Keys must all have the same shape."):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
   def testInsertLookupGrad(self):
     with backprop.GradientTape() as tape:
       m = map_ops.empty_tensor_map()
@@ -397,6 +450,5 @@
     self.assertAllEqual(s, 0)
     self.assertAllEqual(map_ops.tensor_map_has_key(m, k), False)
 
-
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index 712d733..737ca77 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -70,6 +70,7 @@
 
 def _GetMatMulTest(a_np_, b_np_, use_static_shape_, **kwargs_):
 
+  @test_util.run_without_tensor_float_32("Tests matmul")
   def Test(self):
     np_val = np.matrix(a_np_) * np.matrix(b_np_)
 
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index ffe0f59..9a5a467 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -26,7 +26,6 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import benchmark
@@ -41,7 +40,7 @@
       with self.cached_session(use_gpu=True):
         # Verify that x^{-1} * x == Identity matrix.
         inv = linalg_ops.matrix_inverse(y, adjoint=adjoint)
-        tf_ans = math_ops.matmul(inv, y, adjoint_b=adjoint)
+        tf_ans = test_util.matmul_without_tf32(inv, y, adjoint_b=adjoint)
         np_ans = np.identity(y.shape[-1])
         if x.ndim > 2:
           tiling = list(y.shape)
diff --git a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
index 6cf330e..98796f2 100644
--- a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
@@ -29,6 +29,7 @@
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests matrix square root, which calls matmul")
 class SquareRootOpTest(test.TestCase):
 
   def _verifySquareRoot(self, matrix, np_type):
@@ -36,7 +37,7 @@
 
     # Verify that matmul(sqrtm(A), sqrtm(A)) = A
     sqrt = gen_linalg_ops.matrix_square_root(matrix)
-    square = math_ops.matmul(sqrt, sqrt)
+    square = test_util.matmul_without_tf32(sqrt, sqrt)
     self.assertShapeEqual(matrix, square)
     self.assertAllClose(matrix, square, rtol=1e-4, atol=1e-3)
 
diff --git a/tensorflow/python/kernel_tests/parse_single_example_op_test.py b/tensorflow/python/kernel_tests/parse_single_example_op_test.py
index ab270bf..498b6a8 100644
--- a/tensorflow/python/kernel_tests/parse_single_example_op_test.py
+++ b/tensorflow/python/kernel_tests/parse_single_example_op_test.py
@@ -856,6 +856,7 @@
                                                  expected_err[1]):
           out = parsing_ops.parse_single_example(**kwargs)
           sess.run(flatten_values_tensors_or_sparse(out.values()))
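+        # The error case has been verified; skip the value checks below.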
+        return
       else:
         # Returns dict w/ Tensors and SparseTensors.
         out = parsing_ops.parse_single_example(**kwargs)
@@ -939,6 +940,20 @@
         },
         expected_output)
 
+  def testExampleLongerThanSpec(self):
+    serialized = example(
+        features=features({
+            "a": bytes_feature([b"a", b"b"]),
+        })).SerializeToString()
+    self._test(
+        {
+            "serialized": ops.convert_to_tensor(serialized),
+            "features": {
+                "a": parsing_ops.FixedLenFeature(1, dtypes.string)
+            }
+        },
+        expected_err=(errors_impl.OpError, "Can't parse serialized Example"))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index b895fe4..0a618b7 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -200,6 +200,8 @@
     return a
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32("Tests Qr gradient, which calls matmul"
+                                        )
   def Test(self):
     np.random.seed(42)
     # Optimal stepsize for central difference is O(epsilon^{1/3}).
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index c361f79..135e440 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -336,8 +336,6 @@
       self.assertLess(error.max(), 5 * std)
 
   # Check that minval = maxval is fine iff we're producing no numbers
-  @test_util.disable_tfrt(
-      "TFE_TensorHandleToNumpy not implemented yet. b/156191611")
   def testUniformIntsDegenerate(self):
     for dt in dtypes.int32, dtypes.int64:
       def sample(n):
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 9a927b8..83fdfc7 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -169,10 +169,12 @@
   @test_util.run_in_graph_and_eager_modes
   def testVariableShape(self):
     v = resource_variable_ops.ResourceVariable([1., 1.])
+    vshape = resource_variable_ops.variable_shape(v.handle)
     self.assertAllEqual(
-        tensor_util.constant_value(
-            resource_variable_ops.variable_shape(v.handle)),
+        tensor_util.constant_value(vshape),
         [2])
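+    # In graph mode, variable_shape of a fully-defined shape should be
+    # constant-folded into a Const op, which the assertion below checks.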
+    if not context.executing_eagerly():
+      self.assertEqual("Const", vshape.op.type)
 
   @test_util.run_deprecated_v1
   def testDifferentAssignGraph(self):
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index 01b324f..7fa31d1 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -3062,6 +3062,8 @@
 
 
 @test_util.run_all_in_graph_and_eager_modes
+@test_util.run_all_without_tensor_float_32(
+    "Uses an LSTMCell, which calls matmul")
 class DropoutWrapperTest(test.TestCase, parameterized.TestCase):
 
   def _testDropoutWrapper(self,
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index 5be7cb4..40f8b31 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -38,6 +38,7 @@
   setattr(test_class, test_name, fn)
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests self_adjoint_eig, which calls matmul")
 class SelfAdjointEigTest(test.TestCase):
 
   @test_util.run_deprecated_v1
@@ -160,8 +161,8 @@
         tf_e, tf_v = linalg_ops.self_adjoint_eig(constant_op.constant(a))
 
         # Check that V*diag(E)*V^T is close to A.
-        a_ev = math_ops.matmul(
-            math_ops.matmul(tf_v, array_ops.matrix_diag(tf_e)),
+        a_ev = test_util.matmul_without_tf32(
+            test_util.matmul_without_tf32(tf_v, array_ops.matrix_diag(tf_e)),
             tf_v,
             adjoint_b=True)
         self.assertAllClose(self.evaluate(a_ev), a, atol=atol)
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index a031f9b..368a7f1 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -165,6 +165,7 @@
     return a, b, a_dims, b_dims
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32("Tests tensordot, which calls matmul")
   def test_tensordot(self):
     if dynamic_shape_ and context.executing_eagerly():
       self.skipTest("Placeholders not support in eager mode")
@@ -196,6 +197,7 @@
       self.assertAllEqual(tf_ans.shape, np_ans.shape)
 
   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  @test_util.run_without_tensor_float_32("Tests tensordot, which calls matmul")
   def test_tensordot_scalar_axes(self):
     if dynamic_shape_ and context.executing_eagerly():
       self.skipTest("Placeholders not support in eager mode")
diff --git a/tensorflow/python/lib/core/safe_ptr.cc b/tensorflow/python/lib/core/safe_ptr.cc
index 2194f24..ce852a4 100644
--- a/tensorflow/python/lib/core/safe_ptr.cc
+++ b/tensorflow/python/lib/core/safe_ptr.cc
@@ -17,10 +17,6 @@
 
 namespace tensorflow {
 
-Safe_PyObjectPtr make_safe(PyObject* object) {
-  return Safe_PyObjectPtr(object);
-}
-
 Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) {
   return Safe_TF_TensorPtr(tensor);
 }
diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h
index 44d14e9..00f47d7 100644
--- a/tensorflow/python/lib/core/safe_ptr.h
+++ b/tensorflow/python/lib/core/safe_ptr.h
@@ -16,20 +16,17 @@
 #ifndef TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
 #define TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
 
-#include <memory>
-
 #include <Python.h>
 
+#include <memory>
+
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
 
 namespace tensorflow {
 namespace detail {
 
-struct PyDecrefDeleter {
-  void operator()(PyObject* p) const { Py_DECREF(p); }
-};
-
 struct TFTensorDeleter {
   void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
 };
@@ -48,11 +45,6 @@
 
 }  // namespace detail
 
-// Safe container for an owned PyObject. On destruction, the reference count of
-// the contained object will be decremented.
-using Safe_PyObjectPtr = std::unique_ptr<PyObject, detail::PyDecrefDeleter>;
-Safe_PyObjectPtr make_safe(PyObject* o);
-
 // Safe containers for an owned TF_Tensor. On destruction, the tensor will be
 // deleted by TF_DeleteTensor.
 using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, detail::TFTensorDeleter>;
diff --git a/tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc b/tensorflow/python/lib/core/safe_pyobject_ptr.cc
similarity index 69%
rename from tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc
rename to tensorflow/python/lib/core/safe_pyobject_ptr.cc
index 44ce384..966d3ec 100644
--- a/tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc
+++ b/tensorflow/python/lib/core/safe_pyobject_ptr.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,7 +13,12 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
 
-// Static initialization for TensorFlow.js op registration.
-static mlir::DialectRegistration<mlir::tfjs::TFJSDialect> tfjs_ops;
+namespace tensorflow {
+
+Safe_PyObjectPtr make_safe(PyObject* object) {
+  return Safe_PyObjectPtr(object);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/lib/core/safe_pyobject_ptr.h b/tensorflow/python/lib/core/safe_pyobject_ptr.h
new file mode 100644
index 0000000..496bfed
--- /dev/null
+++ b/tensorflow/python/lib/core/safe_pyobject_ptr.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_SAFE_PYOBJECT_PTR_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_SAFE_PYOBJECT_PTR_H_
+
+#include <Python.h>
+
+#include <memory>
+
+namespace tensorflow {
+namespace detail {
+
+struct PyDecrefDeleter {
+  void operator()(PyObject* p) const { Py_DECREF(p); }
+};
+
+}  // namespace detail
+
+// Safe container for an owned PyObject. On destruction, the reference count of
+// the contained object will be decremented.
+using Safe_PyObjectPtr = std::unique_ptr<PyObject, detail::PyDecrefDeleter>;
+Safe_PyObjectPtr make_safe(PyObject* o);
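+//
+// Example (illustrative):
+//
+//   Safe_PyObjectPtr obj = make_safe(PyLong_FromLong(42));
+//   // Py_DECREF is called automatically when `obj` goes out of scope.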
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_PYTHON_LIB_CORE_SAFE_PYOBJECT_PTR_H_
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 5d68deb..4a2d04d 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -4489,6 +4489,23 @@
   <tf.Tensor: shape=(4,), dtype=int32, numpy=array([100, 100, 100, 100],
   dtype=int32)>
 
+  Note that if the gradient of either branch of the `tf.where` generates
+  a NaN, then the gradient of the entire `tf.where` will be NaN. A
+  workaround is to use an inner `tf.where` to replace dangerous inputs
+  with safe inputs, so that no value with a NaN gradient is ever computed
+  and the function has no asymptote.
+
+  Instead of this:
+
+  >>> y = tf.constant(-1, dtype=tf.float32)
+  >>> tf.where(y > 0, tf.sqrt(y), y)
+  <tf.Tensor: shape=(), dtype=float32, numpy=-1.0>
+
+  Use this:
+
+  >>> tf.where(y > 0, tf.sqrt(tf.where(y > 0, y, 1)), y)
+  <tf.Tensor: shape=(), dtype=float32, numpy=-1.0>
+
   Args:
     condition: A `tf.Tensor` of type `bool`
     x: If provided, a Tensor which is of the same type as `y`, and has a shape
diff --git a/tensorflow/python/ops/collective_ops_multi_worker_test.py b/tensorflow/python/ops/collective_ops_multi_worker_test.py
new file mode 100644
index 0000000..4385a20
--- /dev/null
+++ b/tensorflow/python/ops/collective_ops_multi_worker_test.py
@@ -0,0 +1,139 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for multi worker Collective Operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import os
+import time
+
+from tensorflow.core.protobuf import tensorflow_server_pb2
+from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
+from tensorflow.python.distribute import multi_process_runner
+from tensorflow.python.distribute import multi_worker_test_base
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import collective_ops
+
+
+def enable_collective_ops(cluster_resolver):
+  context.context().configure_collective_ops(
+      collective_leader="/job:worker/replica:0/task:0")
+  config_proto = copy.deepcopy(context.context().config)
+  server_def = tensorflow_server_pb2.ServerDef(
+      cluster=cluster_resolver.cluster_spec().as_cluster_def(),
+      default_session_config=config_proto,
+      job_name=cluster_resolver.task_type,
+      task_index=cluster_resolver.task_id,
+      protocol=cluster_resolver.rpc_layer or "grpc")
+  context.context().enable_collective_ops(server_def)
+
+
+class CollectiveOpTest(test.TestCase):
+
+  def testCheckHealth(self):
+
+    def worker_fn():
+      enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
+      # There may be some delay before the servers start up, so the health
+      # check should eventually succeed.
+      while True:
+        try:
+          for task in [
+              "/job:worker/replica:0/task:0",
+              "/job:worker/replica:0/task:1",
+          ]:
+            context.context().check_collective_ops_peer_health(task)
+        except errors.UnavailableError:
+          continue
+        break
+      multi_process_runner.barrier().wait()
+
+    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
+    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
+    mpr.start()
+    mpr.join()
+
+  def testCheckHealthPeerDown(self):
+
+    def worker_fn():
+      enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
+      context.context().check_collective_ops_peer_health(
+          "/job:worker/replica:0/task:1",)
+
+    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
+    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
+    mpr.start_single_process("worker", 0)
+    with self.assertRaises(errors.UnavailableError):
+      mpr.join()
+
+  def testCheckHealthPeerRestart(self):
+
+    def worker_fn():
+      cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
+      enable_collective_ops(cluster_resolver)
+
+      collective_ops.all_reduce(
+          constant_op.constant(1.),
+          group_size=2,
+          group_key=100,
+          instance_key=100,
+          merge_op="Add",
+          final_op="Id",
+          communication_hint="ring")
+
+      if cluster_resolver.task_type == "worker":
+        # MultiProcessRunner will auto restart worker-0.
+        os._exit(1)  # pylint: disable=protected-access
+      else:
+        # The chief should eventually get a FailedPreconditionError after
+        # worker-0 has restarted.
+        while True:
+          time.sleep(1)
+          try:
+            context.context().check_collective_ops_peer_health(
+                "/job:worker/replica:0/task:0",)
+          except errors.UnavailableError:
+            pass
+          except errors.FailedPreconditionError:
+            break
+
+    cluster_spec = multi_worker_test_base.create_cluster_spec(
+        has_chief=True, num_workers=1)
+    mpr = multi_process_runner.MultiProcessRunner(
+        worker_fn, cluster_spec, auto_restart=True)
+    mpr.start()
+    mpr.join()
+
+  def testCheckHealthInvalidPeer(self):
+
+    def worker_fn():
+      enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
+      context.context().check_collective_ops_peer_health("localhost:12345",)
+
+    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
+    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
+    mpr.start_single_process("worker", 0)
+    with self.assertRaises(errors.InvalidArgumentError):
+      mpr.join()
+
+
+if __name__ == "__main__":
+  multi_process_runner.test_main()
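
These tests resolve the cluster through TFConfigClusterResolver, which reads the TF_CONFIG environment variable that multi_worker_test_base and multi_process_runner set up for each subprocess. A minimal sketch of the layout for two workers (addresses are illustrative):

    import json
    import os

    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            # One address per worker task; the ports are illustrative.
            "worker": ["localhost:12345", "localhost:23456"],
        },
        # This process acts as worker 0.
        "task": {"type": "worker", "index": 0},
    })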
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 17a5d5e..163f0fb 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -25,7 +25,6 @@
 
 import collections
 
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
@@ -1120,10 +1119,7 @@
         op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
     ])
 
-  # TODO(b/161915509): Remove this after 08/20/2020. This is required to abide
-  # by 3-week forward compat window of new TF python op generating code with
-  # stale runtime binaries.
-  if (stateful_ops or not compat.forward_compatible(2020, 8, 20)):
+  if stateful_ops:
     op_fn = gen_functional_ops.case
   else:
     op_fn = gen_functional_ops.stateless_case
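
For reference, the gate removed here followed the usual forward-compatibility pattern: the new op is only emitted once the compatibility window has elapsed, so graphs produced by new Python code still run on slightly stale runtime binaries. A minimal sketch of that pattern (the date is the one from the removed check):

    from tensorflow.python.compat import compat

    def pick_case_op(stateful_ops):
        # Before the window expires, keep emitting the old op so that
        # binaries built before the change can still execute the graph.
        if stateful_ops or not compat.forward_compatible(2020, 8, 20):
            return "Case"           # old op, understood by stale runtimes
        return "StatelessCase"      # new op, safe once the window has passed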
diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py
index c2ee2b0..3da536c 100644
--- a/tensorflow/python/ops/image_grad_test.py
+++ b/tensorflow/python/ops/image_grad_test.py
@@ -184,20 +184,21 @@
     out_shape = [1, 2, 3, 1]
     x = np.arange(0, 24).reshape(in_shape)
 
-    with self.cached_session() as sess:
-      for dtype in [np.float16, np.float32, np.float64]:
-        input_tensor = constant_op.constant(x.astype(dtype), shape=in_shape)
-        resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
-        grad = sess.run(gradients_impl.gradients(resize_out, input_tensor))[0]
-        self.assertAllEqual(in_shape, grad.shape)
-        # Not using gradient_checker.compute_gradient as I didn't work out
-        # the changes required to compensate for the lower precision of
-        # float16 when computing the numeric jacobian.
-        # Instead, we just test the theoretical jacobian.
-        self.assertAllEqual([[[[1.], [0.], [1.], [0.], [1.], [0.]], [[0.], [
-            0.
-        ], [0.], [0.], [0.], [0.]], [[1.], [0.], [1.], [0.], [1.], [0.]],
-                              [[0.], [0.], [0.], [0.], [0.], [0.]]]], grad)
+    for use_gpu in [False, True]:
+      with self.cached_session(use_gpu=use_gpu) as sess:
+        for dtype in [np.float16, np.float32, np.float64]:
+          input_tensor = constant_op.constant(x.astype(dtype), shape=in_shape)
+          resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
+          grad = sess.run(gradients_impl.gradients(resize_out, input_tensor))[0]
+          self.assertAllEqual(in_shape, grad.shape)
+          # Not using gradient_checker.compute_gradient, as that would require
+          # compensating for the lower precision of float16 when computing the
+          # numeric jacobian. Instead, we just test the theoretical jacobian.
+          self.assertAllEqual([[[[1.], [0.], [1.], [0.], [1.], [0.]],
+                                [[0.], [0.], [0.], [0.], [0.], [0.]],
+                                [[1.], [0.], [1.], [0.], [1.], [0.]],
+                                [[0.], [0.], [0.], [0.], [0.], [0.]]]], grad)
 
 
 class ResizeBicubicOpTest(test.TestCase):
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 1e737c1..1c8d8d6 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -97,6 +97,8 @@
 
 class RGBToYIQTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_without_tensor_float_32(
+      "Calls rgb_to_yiq and yiq_to_rgb, which use matmul")
   def testBatch(self):
     # Build an arbitrary RGB image
     np.random.seed(7)
@@ -127,6 +129,8 @@
 
 class RGBToYUVTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_without_tensor_float_32(
+      "Calls rgb_to_yuv and yuv_to_rgb, which use matmul")
   def testBatch(self):
     # Build an arbitrary RGB image
     np.random.seed(7)
@@ -1564,12 +1568,12 @@
       y_tf = self._adjustContrastTf(x_np, contrast_factor)
       self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5)
 
-  @test_util.run_deprecated_v1
   def testContrastFactorShape(self):
     x_shape = [1, 2, 2, 3]
     x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
     x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
-    with self.assertRaisesRegex(ValueError,
+    with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
+                                "contrast_factor must be scalar|"
                                 "Shape must be rank 0 but is rank 1"):
       image_ops.adjust_contrast(x_np, [2.0])
 
@@ -1637,7 +1641,6 @@
     y /= stddev
     return y
 
-  @test_util.run_deprecated_v1
   def testBasic(self):
     x_shape = [13, 9, 3]
     x_np = np.arange(0, np.prod(x_shape), dtype=np.float32).reshape(x_shape)
@@ -1646,7 +1649,6 @@
     with self.cached_session(use_gpu=True):
       x = constant_op.constant(x_np, shape=x_shape)
       y = image_ops.per_image_standardization(x)
-      self.assertTrue(y.op.name.startswith("per_image_standardization"))
       y_tf = self.evaluate(y)
       self.assertAllClose(y_tf, y_np, atol=1e-4)
 
@@ -1873,7 +1875,6 @@
     else:
       self.assertEqual(y.get_shape().as_list(), post_shape)
 
-  @test_util.run_deprecated_v1
   def testNoOp(self):
     x_shapes = [[13, 9, 3], [5, 13, 9, 3]]
     for x_shape in x_shapes:
@@ -1884,7 +1885,6 @@
           y = image_ops.central_crop(x, 1.0)
           y_tf = self.evaluate(y)
           self.assertAllEqual(y_tf, x_np)
-          self.assertEqual(y.op.name, x.op.name)
 
   def testCropping(self):
     x_shape = [4, 8, 1]
@@ -1917,7 +1917,6 @@
       self.assertAllEqual(y_tf, y_np)
       self.assertAllEqual(y_tf.shape, y_np.shape)
 
-  @test_util.run_deprecated_v1
   def testCropping2(self):
     # Test case for 10315
     x_shapes = [[240, 320, 3], [5, 240, 320, 3]]
@@ -1928,51 +1927,50 @@
       y_np = np.zeros(y_shape, dtype=np.int32)
       for use_gpu in [True, False]:
         with self.cached_session(use_gpu=use_gpu):
-          x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32)
-          y = image_ops.central_crop(x, 0.33)
-          y_tf = y.eval(feed_dict={x: x_np})
+          y_tf = self.evaluate(image_ops.central_crop(x_np, 0.33))
           self.assertAllEqual(y_tf, y_np)
           self.assertAllEqual(y_tf.shape, y_np.shape)
 
-  @test_util.run_deprecated_v1
   def testShapeInference(self):
-    # Test no-op fraction=1.0, with 3-D tensors.
-    self._assertShapeInference([50, 60, 3], 1.0, [50, 60, 3])
-    self._assertShapeInference([None, 60, 3], 1.0, [None, 60, 3])
-    self._assertShapeInference([50, None, 3], 1.0, [50, None, 3])
-    self._assertShapeInference([None, None, 3], 1.0, [None, None, 3])
-    self._assertShapeInference([50, 60, None], 1.0, [50, 60, None])
-    self._assertShapeInference([None, None, None], 1.0, [None, None, None])
+    # Shape function requires placeholders and a graph.
+    with ops.Graph().as_default():
+      # Test no-op fraction=1.0, with 3-D tensors.
+      self._assertShapeInference([50, 60, 3], 1.0, [50, 60, 3])
+      self._assertShapeInference([None, 60, 3], 1.0, [None, 60, 3])
+      self._assertShapeInference([50, None, 3], 1.0, [50, None, 3])
+      self._assertShapeInference([None, None, 3], 1.0, [None, None, 3])
+      self._assertShapeInference([50, 60, None], 1.0, [50, 60, None])
+      self._assertShapeInference([None, None, None], 1.0, [None, None, None])
 
-    # Test no-op fraction=0.5, with 3-D tensors.
-    self._assertShapeInference([50, 60, 3], 0.5, [26, 30, 3])
-    self._assertShapeInference([None, 60, 3], 0.5, [None, 30, 3])
-    self._assertShapeInference([50, None, 3], 0.5, [26, None, 3])
-    self._assertShapeInference([None, None, 3], 0.5, [None, None, 3])
-    self._assertShapeInference([50, 60, None], 0.5, [26, 30, None])
-    self._assertShapeInference([None, None, None], 0.5, [None, None, None])
+      # Test no-op fraction=0.5, with 3-D tensors.
+      self._assertShapeInference([50, 60, 3], 0.5, [26, 30, 3])
+      self._assertShapeInference([None, 60, 3], 0.5, [None, 30, 3])
+      self._assertShapeInference([50, None, 3], 0.5, [26, None, 3])
+      self._assertShapeInference([None, None, 3], 0.5, [None, None, 3])
+      self._assertShapeInference([50, 60, None], 0.5, [26, 30, None])
+      self._assertShapeInference([None, None, None], 0.5, [None, None, None])
 
-    # Test no-op fraction=1.0, with 4-D tensors.
-    self._assertShapeInference([5, 50, 60, 3], 1.0, [5, 50, 60, 3])
-    self._assertShapeInference([5, None, 60, 3], 1.0, [5, None, 60, 3])
-    self._assertShapeInference([5, 50, None, 3], 1.0, [5, 50, None, 3])
-    self._assertShapeInference([5, None, None, 3], 1.0, [5, None, None, 3])
-    self._assertShapeInference([5, 50, 60, None], 1.0, [5, 50, 60, None])
-    self._assertShapeInference([5, None, None, None], 1.0,
-                               [5, None, None, None])
-    self._assertShapeInference([None, None, None, None], 1.0,
-                               [None, None, None, None])
+      # Test no-op fraction=1.0, with 4-D tensors.
+      self._assertShapeInference([5, 50, 60, 3], 1.0, [5, 50, 60, 3])
+      self._assertShapeInference([5, None, 60, 3], 1.0, [5, None, 60, 3])
+      self._assertShapeInference([5, 50, None, 3], 1.0, [5, 50, None, 3])
+      self._assertShapeInference([5, None, None, 3], 1.0, [5, None, None, 3])
+      self._assertShapeInference([5, 50, 60, None], 1.0, [5, 50, 60, None])
+      self._assertShapeInference([5, None, None, None], 1.0,
+                                 [5, None, None, None])
+      self._assertShapeInference([None, None, None, None], 1.0,
+                                 [None, None, None, None])
 
-    # Test no-op fraction=0.5, with 4-D tensors.
-    self._assertShapeInference([5, 50, 60, 3], 0.5, [5, 26, 30, 3])
-    self._assertShapeInference([5, None, 60, 3], 0.5, [5, None, 30, 3])
-    self._assertShapeInference([5, 50, None, 3], 0.5, [5, 26, None, 3])
-    self._assertShapeInference([5, None, None, 3], 0.5, [5, None, None, 3])
-    self._assertShapeInference([5, 50, 60, None], 0.5, [5, 26, 30, None])
-    self._assertShapeInference([5, None, None, None], 0.5,
-                               [5, None, None, None])
-    self._assertShapeInference([None, None, None, None], 0.5,
-                               [None, None, None, None])
+      # Test no-op fraction=0.5, with 4-D tensors.
+      self._assertShapeInference([5, 50, 60, 3], 0.5, [5, 26, 30, 3])
+      self._assertShapeInference([5, None, 60, 3], 0.5, [5, None, 30, 3])
+      self._assertShapeInference([5, 50, None, 3], 0.5, [5, 26, None, 3])
+      self._assertShapeInference([5, None, None, 3], 0.5, [5, None, None, 3])
+      self._assertShapeInference([5, 50, 60, None], 0.5, [5, 26, 30, None])
+      self._assertShapeInference([5, None, None, None], 0.5,
+                                 [5, None, None, None])
+      self._assertShapeInference([None, None, None, None], 0.5,
+                                 [None, None, None, None])
 
   def testErrorOnInvalidCentralCropFractionValues(self):
     x_shape = [13, 9, 3]
@@ -1995,14 +1993,15 @@
           with self.assertRaises(ValueError):
             _ = image_ops.central_crop(x, 0.5)
 
-  @test_util.run_deprecated_v1
   def testNameScope(self):
-    x_shape = [13, 9, 3]
-    x_np = np.ones(x_shape, dtype=np.float32)
-    for use_gpu in [True, False]:
-      with self.cached_session(use_gpu=use_gpu):
-        y = image_ops.central_crop(x_np, 1.0)
-        self.assertTrue(y.op.name.startswith("central_crop"))
+    # Testing name scope requires a graph.
+    with ops.Graph().as_default():
+      x_shape = [13, 9, 3]
+      x_np = np.ones(x_shape, dtype=np.float32)
+      for use_gpu in [True, False]:
+        with self.cached_session(use_gpu=use_gpu):
+          y = image_ops.central_crop(x_np, 1.0)
+          self.assertTrue(y.op.name.startswith("central_crop"))
 
 
 class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/map_ops.py b/tensorflow/python/ops/map_ops.py
index ce8b7a6..7315e7e 100644
--- a/tensorflow/python/ops/map_ops.py
+++ b/tensorflow/python/ops/map_ops.py
@@ -46,6 +46,11 @@
 def tensor_map_has_key(input_handle, key):
   return gen_map_ops.tensor_map_has_key(input_handle, key)
 
+
+def tensor_map_stack_keys(input_handle, key_dtype):
+  return gen_map_ops.tensor_map_stack_keys(input_handle, key_dtype)
+
+
 @ops.RegisterGradient("TensorMapLookup")
 def LookupGrad(op, dval):
   _, k = op.inputs
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 9b864be..7f3d9f6 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -541,6 +541,8 @@
       _ = nn_ops.dropout(x, 0.5)
 
 
+@test_util.run_all_without_tensor_float_32(
+    "Tests _compute_sampled_logits and related functions, which call matmul")
 class ComputeSampledLogitsTest(test_lib.TestCase):
 
   def setUp(self):
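
The run_without_tensor_float_32 decorators added in these test files exist because TensorFloat-32 lowers float32 matmul precision on Ampere GPUs below what the tests' tolerances allow. A sketch of the underlying toggle, assuming the tf.config.experimental API of this era:

    import tensorflow as tf

    # TF32 is on by default where the hardware supports it; turning it off
    # restores full float32 precision for matmuls and convolutions.
    tf.config.experimental.enable_tensor_float_32_execution(False)
    assert not tf.config.experimental.tensor_float_32_execution_enabled()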
diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py
index 9a85904..ade758d 100644
--- a/tensorflow/python/ops/numpy_ops/np_arrays.py
+++ b/tensorflow/python/ops/numpy_ops/np_arrays.py
@@ -294,6 +294,19 @@
   # NOTE: we currently prefer interop with TF to allow TF to take precedence.
   __array_priority__ = 90
 
+  def __array_module__(self, types):
+    # Experimental support for NumPy's module dispatch with NEP-37:
+    # https://numpy.org/neps/nep-0037-array-module.html
+    # Currently requires https://github.com/seberg/numpy-dispatch
+
+    # pylint: disable=g-import-not-at-top
+    import tensorflow.compat.v2 as tf
+
+    if all(issubclass(t, (ndarray, np.ndarray)) for t in types):
+      return tf.experimental.numpy
+    else:
+      return NotImplemented
+
   def __index__(self):
     """Returns a python scalar.
 
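
A minimal sketch of the NEP-37 dispatch that __array_module__ enables, mirroring the new interop test below; when every participating type is a TF or NumPy ndarray, the protocol hands back tf.experimental.numpy so subsequent calls stay in TensorFlow:

    import tensorflow.compat.v2 as tf

    arr = tf.experimental.numpy.asarray([10])
    module = arr.__array_module__((type(arr),))
    # module is tf.experimental.numpy, so e.g. module.sum(arr) runs as a
    # TF op and returns a TF ndarray rather than falling back to NumPy.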
diff --git a/tensorflow/python/ops/numpy_ops/np_interop_test.py b/tensorflow/python/ops/numpy_ops/np_interop_test.py
index 3b52ae5..8999c8f 100644
--- a/tensorflow/python/ops/numpy_ops/np_interop_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_interop_test.py
@@ -174,6 +174,17 @@
     self.assertIsInstance(sq, onp.ndarray)
     self.assertEqual(100., sq[0])
 
+  def testArrayModule(self):
+    arr = np.asarray([10])
+
+    module = arr.__array_module__((np.ndarray,))
+    self.assertIs(module, tf.experimental.numpy)
+
+    class Dummy:
+      pass
+    module = arr.__array_module__((np.ndarray, Dummy))
+    self.assertIs(module, NotImplemented)
+
     # TODO(nareshmodi): Fails since the autopacking code doesn't use
     # nest.flatten.
 
@@ -323,6 +334,11 @@
     self.assertIsInstance(c, np.ndarray)
     self.assertEqual(c.shape, (batch_size, 32, 32, 32, 32))
 
+    c = tf.vectorized_map(lambda x: x.T, a)
+
+    self.assertIsInstance(c, np.ndarray)
+    self.assertEqual(c.shape, (batch_size, 32, 32))
+
   def testJacobian(self):
     with tf.GradientTape() as g:
       x = np.asarray([1., 2.])
diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py
index c1505e6..631975c 100644
--- a/tensorflow/python/ops/numpy_ops/np_math_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py
@@ -217,6 +217,9 @@
             *np_utils.tf_broadcast(a.data, a_min.data, a_max.data)))
 
 
+setattr(np_arrays.ndarray, 'clip', clip)
+
+
 @np_utils.np_doc('matmul')
 def matmul(x1, x2):  # pylint: disable=missing-docstring
   def f(x1, x2):
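
With the setattr above, clip also becomes available as a method on the numpy-compatibility ndarray; a minimal usage sketch:

    import tensorflow.compat.v2 as tf

    tnp = tf.experimental.numpy
    x = tnp.asarray([-1.0, 0.5, 2.0])
    print(x.clip(0.0, 1.0))  # values clamped to [0, 1]: [0., 0.5, 1.]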
diff --git a/tensorflow/python/ops/parallel_for/array_test.py b/tensorflow/python/ops/parallel_for/array_test.py
index 1e2ecdb..d449050 100644
--- a/tensorflow/python/ops/parallel_for/array_test.py
+++ b/tensorflow/python/ops/parallel_for/array_test.py
@@ -59,6 +59,11 @@
         outputs.append(array_ops.gather(y, [i, 1, 2], axis=2, batch_dims=1))
         outputs.append(array_ops.gather(y, [[2, i], [i, 1], [2, 1]],
                                         axis=-1, batch_dims=1))
+        outputs.append(
+            array_ops.gather(y, [[0, 1, 2]] * 3, axis=2, batch_dims=2))
+        outputs.append(array_ops.gather(y, [0, 1, 2], axis=1, batch_dims=-1))
+        outputs.append(
+            array_ops.gather(y, [[0, 1, 2]] * 3, axis=2, batch_dims=-2))
 
       return outputs
 
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index e7a5c38..b60bc21 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -357,7 +357,10 @@
     i = 0
   elif static_first_dim is None:
     i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0)
-  return array_ops.gather(x, i)
+  result = array_ops.gather(x, i)
+  if isinstance(x, np_arrays.ndarray):
+    result = np_arrays.ndarray.from_tensor(result)
+  return result
 
 
 @tf_export("vectorized_map")
@@ -450,7 +453,11 @@
   Raises:
     ValueError: If vectorization fails and fallback_to_while_loop is False.
   """
-  elems = nest.map_structure(ops.convert_to_tensor, elems)
+  def _convert_to_tensor_or_ndarray(x):
+    if isinstance(x, np_arrays.ndarray):
+      return x
+    return ops.convert_to_tensor(x)
+  elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
 
   def loop_fn(i):
     gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
@@ -459,9 +466,13 @@
 
   # Extract batch size from the maximum first dimension of any element.
   flat_elems = nest.flatten(elems)
-  static_first_dims = [elem.shape.as_list()[0]
-                       if elem.shape.rank is not None else None
-                       for elem in flat_elems]
+  def _get_shape(x):
+    if isinstance(x, np_arrays.ndarray):
+      x = x.data
+    if x.shape.rank is None:
+      return None
+    return x.shape.as_list()[0]
+  static_first_dims = [_get_shape(elem) for elem in flat_elems]
   if any([s is None for s in static_first_dims]):
     batch_size = math_ops.reduce_max(
         [array_ops.shape(elem)[0] for elem in flat_elems])
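
Together these changes let tf.vectorized_map round-trip the numpy-compatibility ndarray type instead of silently returning a bare tf.Tensor; a sketch mirroring the new interop test:

    import tensorflow.compat.v2 as tf

    tnp = tf.experimental.numpy
    a = tnp.zeros([8, 4, 4])
    c = tf.vectorized_map(lambda x: x.T, a)  # transpose each 4x4 slice
    # c comes back as a tnp ndarray of shape (8, 4, 4), not a bare tf.Tensor.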
diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py
index 85b5805..30e7244 100644
--- a/tensorflow/python/ops/parallel_for/math_test.py
+++ b/tensorflow/python/ops/parallel_for/math_test.py
@@ -261,6 +261,9 @@
 
     self._test_loop_fn(loop_fn, 4)
 
+  @test_util.run_without_tensor_float_32(
+      "Calls matmul in parallel for-loop and compares result to calling matmul "
+      "in sequential for-loop")
   def test_matmul(self):
     for tr_a in (True, False):
       for tr_b in (True, False):
@@ -745,6 +748,9 @@
 
     self._test_loop_fn(loop_fn, 2)
 
+  @test_util.run_without_tensor_float_32(
+      "Calls einsum in parallel for-loop and compares result to calling einsum "
+      "in sequential for-loop")
   def test_einsum(self):
     b = 10
     x_series = random_ops.random_uniform([b, 9, 9])
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index d14ad1e..cde1e6a 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -2275,7 +2275,11 @@
         # it must be picking up all the rows of param.
         return wrap(param, True)
 
-    if batch_dims > 0:
+    if batch_dims != 0:
+      # Convert `batch_dims` to its positive equivalent if necessary.
+      batch_dims_pos = batch_dims
+      if batch_dims < 0:
+        batch_dims_pos += array_ops.rank(indices)
       # In order to maintain
       #   indices.shape[:batch_dims] == params.shape[:batch_dims]
       # with stacked indices, we move the first dimension of `indices` to the
@@ -2283,8 +2287,9 @@
       # inserted into the shape of `output` at the `axis` dimension, which is
       # then transposed to the front (below).
       order = array_ops.concat([
-          (list(range(1, batch_dims + 1)) + [0]),
-          math_ops.range(batch_dims + 1, array_ops.rank(indices))], axis=0)
+          math_ops.range(1, batch_dims_pos + 1),
+          [0],
+          math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0)
       indices = array_ops.transpose(indices, order)
 
     output = array_ops.gather(
@@ -2310,7 +2315,7 @@
     output = array_ops.gather(
         param, indices,
         axis=array_ops.where(axis >= 0, axis + 1, axis),
-        batch_dims=batch_dims + 1)
+        batch_dims=(batch_dims + 1 if batch_dims >= 0 else batch_dims))
     return wrap(output, True)
 
 
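
The negative batch_dims handling above is exercised by the new cases in array_test.py; a hedged sketch of the kind of call that now vectorizes (shapes are illustrative):

    import tensorflow as tf

    params = tf.random.uniform([4, 3, 3, 3])
    # Inside pfor, batch_dims=-1 is first resolved against the rank of
    # `indices` before the extra loop dimension is accounted for.
    out = tf.vectorized_map(
        lambda y: tf.gather(y, [0, 1, 2], axis=1, batch_dims=-1), params)
    print(out.shape)  # (4, 3, 3, 3)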
diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
index 8a40e39..bead492 100644
--- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py
@@ -150,6 +150,21 @@
           result_dtype=ragged_tensor.RaggedTensorType(
               dtype=dtypes.int64, ragged_rank=4),
       ),
+      # [d1] -> [d1, (d2), (d3)]
+      dict(
+          fn=ragged_math_ops.range,
+          elems=np.array([1, 2, 3], np.int64),
+          expected_output=[[[0]], [[0, 1]], [[0, 1, 2]]],
+          result_dtype=ragged_tensor.RaggedTensorType(
+              dtype=dtypes.int64, ragged_rank=2)),
+      # [0] -> [0, (d2), (d3)]  (github issue #36232)
+      dict(
+          fn=ragged_math_ops.range,
+          elems=np.zeros([0], np.int64),
+          expected_output=[],
+          expected_ragged_rank=2,
+          result_dtype=ragged_tensor.RaggedTensorType(
+              dtype=dtypes.int64, ragged_rank=2)),
   ])
 
   def testRaggedMap(
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index 623f506..ba184b2 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -635,6 +635,8 @@
 
 
 @test_util.run_all_in_graph_and_eager_modes
+@test_util.run_all_without_tensor_float_32(
+    'Tests einsum, which sometimes does a matmul with cuBLAS')
 class EinsumTest(test.TestCase):
 
   def _check(self, s, *input_shapes, **kwargs):
diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD
index 98e9336..5eeaf96 100644
--- a/tensorflow/python/profiler/internal/BUILD
+++ b/tensorflow/python/profiler/internal/BUILD
@@ -142,7 +142,10 @@
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/core:lib",
-        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "//tensorflow/core/profiler/utils:xplane_builder",
+        "//tensorflow/core/profiler/utils:xplane_schema",
+        "//tensorflow/core/profiler/utils:xplane_utils",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
         "@pybind11",
diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc
index 4012010..7ddce91 100644
--- a/tensorflow/python/profiler/internal/profiler_wrapper.cc
+++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc
@@ -16,6 +16,7 @@
 #include <memory>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/match.h"
 #include "absl/strings/numbers.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/pytypes.h"
@@ -50,7 +51,7 @@
   // Must be host:port, port must be a number, host must not contain a '/',
   // host also must not be empty.
   if (parts.size() != 2 || !absl::SimpleAtoi(parts[1], &port) ||
-      parts[0].find("/") != std::string::npos || parts[0].empty()) {
+      absl::StrContains(parts[0], "/") || parts[0].empty()) {
     return tensorflow::errors::InvalidArgument(
         "Could not interpret \"", host_port, "\" as a host-port pair.");
   }
@@ -123,7 +124,8 @@
       .def("export_to_tb", &ProfilerSessionWrapper::ExportToTensorBoard);
 
   m.def("start_server", [](int port) {
-    auto profiler_server = absl::make_unique<tensorflow::ProfilerServer>();
+    auto profiler_server =
+        absl::make_unique<tensorflow::profiler::ProfilerServer>();
     profiler_server->StartProfilerServer(port);
     // Intentionally release profiler server. Should transfer ownership to
     // caller instead.
diff --git a/tensorflow/python/profiler/internal/python_hooks.cc b/tensorflow/python/profiler/internal/python_hooks.cc
index ee2ad1e..aa59305 100644
--- a/tensorflow/python/profiler/internal/python_hooks.cc
+++ b/tensorflow/python/profiler/internal/python_hooks.cc
@@ -16,13 +16,19 @@
 
 #include "absl/strings/string_view.h"
 #include "absl/strings/strip.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/profiler/utils/xplane_builder.h"
+#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_utils.h"
 
 namespace tensorflow {
 namespace profiler {
 
 namespace py = ::pybind11;
 
+namespace {
+
 template <typename T>
 int ProfileFunction(PyObject* obj, PyFrameObject* frame, int what,
                     PyObject* arg) {
@@ -40,40 +46,153 @@
   setprofile(callback);
 }
 
+std::string GetEventName(PyCodeObject* py_code) {
+  string filename(py::reinterpret_borrow<py::str>(py_code->co_filename));
+  string function;
+  if (py_code->co_name == nullptr) {
+    function = "<unknown>";
+  } else {
+    function = py::reinterpret_borrow<py::str>(py_code->co_name);
+  }
+
+  return absl::StrCat("$", io::Basename(filename), ":", py_code->co_firstlineno,
+                      " ", function);
+}
+
+std::string GetEventName(PyCFunctionObject* py_cfunc) {
+  PyObject* module = py_cfunc->m_module;
+  string filename;
+  bool filename_ok;
+#if PY_MAJOR_VERSION < 3
+  filename_ok = (module != nullptr && PyString_Check(module));
+#else
+  filename_ok = (module != nullptr && PyUnicode_Check(module));
+#endif
+  if (filename_ok) {
+    filename = py::reinterpret_borrow<py::str>(module);
+  } else {
+    filename = "<unknown>";
+  }
+
+  return absl::StrCat("$", filename, " ", py_cfunc->m_ml->ml_name);
+}
+
+void AddEventToXLine(const PythonTraceEntry& event, XLineBuilder* line,
+                     XPlaneBuilder* plane) {
+  // TODO(jiesun): maybe add full filename as event stats.
+  auto xevent = line->AddEvent(*plane->GetOrCreateEventMetadata(event.Name()));
+  xevent.SetTimestampNs(event.start_time_ns);
+  xevent.SetEndTimestampNs(event.end_time_ns);
+}
+
+}  // namespace
+
+std::string PythonTraceEntry::Name() const {
+  std::string event_name;
+  if (code_object) {
+    return GetEventName(code_object);
+  } else if (function_object) {
+    return GetEventName(function_object);
+  }
+  return "<unknown>";
+}
+
 PythonHooks* PythonHooks::GetSingleton() {
   static PythonHooks* singleton = new PythonHooks;
   return singleton;
 }
 
-void PythonHooks::Start(const PythonHooksOptions& option) {
+void PythonHooks::Start(const PythonHooksOptions& options) {
   if (!Py_IsInitialized()) return;
-  if (option.enable_python_traceme || option.enable_trace_python_function) {
+  options_ = options;
+  start_timestamp_ns_ = EnvTime::NowNanos();
+  if (options_.enable_python_traceme || options_.enable_trace_python_function) {
     PyGILState_STATE gil_state = PyGILState_Ensure();
-    if (option.enable_trace_python_function) {
+    if (options_.enable_trace_python_function) {
       SetProfilerInAllThreads();
     }
-    if (option.enable_python_traceme) {
+    if (options_.enable_python_traceme) {
       EnableTraceMe(true);
     }
+    if (options_.end_to_end_mode) {
+      // In end-to-end mode, Stop() and Finalize() (i.e. symbolization and
+      // data collection) happen during C's atexit(), after Py_FinalizeEx()
+      // has already been called.
+      try {
+        auto atexit = py::module::import("atexit");
+        atexit.attr("register")(py::cpp_function([]() {
+          PythonHooks* singleton = PythonHooks::GetSingleton();
+          singleton->Stop();
+          singleton->CollectData(&(singleton->end_to_end_xplane_.emplace()));
+        }));
+      } catch (const py::error_already_set& e) {
+        LOG(ERROR) << "Can't install atexit handler for e2e mode: " << e.what();
+      }
+    }
     PyGILState_Release(gil_state);
+    active_session_ = true;
   }
 }
 
-void PythonHooks::Stop(const PythonHooksOptions& option) {
+void PythonHooks::Stop() {
   if (!Py_IsInitialized()) return;
-  if (option.enable_python_traceme || option.enable_trace_python_function) {
+  if (!active_session_) return;  // Makes sure Stop() can be reentrant.
+  if (options_.enable_python_traceme || options_.enable_trace_python_function) {
     PyGILState_STATE gil_state = PyGILState_Ensure();
-    if (option.enable_trace_python_function) {
+    if (options_.enable_trace_python_function) {
       ClearProfilerInAllThreads();
     }
-    if (option.enable_python_traceme) {
+    if (options_.enable_python_traceme) {
       EnableTraceMe(false);
     }
     PyGILState_Release(gil_state);
+    active_session_ = false;
   }
 }
 
-void PythonHooks::Finalize() { tracemes_.clear(); }
+void PythonHooks::CollectData(XPlane* raw_plane) {
+  DCHECK(raw_plane);
+  XPlaneBuilder plane(raw_plane);
+  for (auto& it : entries_) {
+    uint64 thread_id = it.first;
+    auto& thread_events = it.second;
+    VLOG(1) << "Collecting " << thread_events.completed.size() << ":"
+            << thread_events.active.size() << " events on thread " << thread_id;
+    auto line = plane.GetOrCreateLine(thread_id);
+    line.SetTimestampNs(start_timestamp_ns_);
+    for (const auto& event : thread_events.completed) {
+      AddEventToXLine(event, &line, &plane);
+    }
+    if (options_.include_incomplete_events) {
+      uint64 now = EnvTime::NowNanos();
+      while (!thread_events.active.empty()) {
+        auto& event = thread_events.active.top();
+        event.end_time_ns = now;
+        AddEventToXLine(event, &line, &plane);
+        thread_events.active.pop();
+      }
+    }
+  }
+  entries_.clear();
+}
+
+void PythonHooks::Finalize(XSpace* space) {
+  if (space) {
+    XPlane* plane =
+        FindOrAddMutablePlaneWithName(space, kPythonTracerPlaneName);
+    if (options_.end_to_end_mode) {
+      if (end_to_end_xplane_) {
+        end_to_end_xplane_->set_name(plane->name());
+        plane->Swap(&*end_to_end_xplane_);
+        end_to_end_xplane_.reset();
+      }
+    } else {
+      PyGILState_STATE gil_state = PyGILState_Ensure();
+      CollectData(plane);
+      PyGILState_Release(gil_state);
+    }
+  }
+}
 
 void PythonHooks::ProfileSlow(const py::object& frame, const string& event,
                               const py::object& arg) {
@@ -106,52 +225,58 @@
 }
 
 void PythonHooks::ProfileFast(PyFrameObject* frame, int what, PyObject* arg) {
-  const int64 thread_id = PyThread_get_thread_ident();
+  const int64 thread_id = Env::Default()->GetCurrentThreadId();
+  uint64 now = EnvTime::NowNanos();
+  auto& thread_traces = entries_[thread_id];
 
-  if (what == PyTrace_CALL) {
-    PyCodeObject* f_code = frame->f_code;
-    string filename(py::reinterpret_borrow<py::str>(f_code->co_filename));
-    int line_no = frame->f_lineno;
-
-    string function;
-    if (f_code->co_name == nullptr) {
-      function = "<unknown>";
-    } else {
-      function = py::reinterpret_borrow<py::str>(f_code->co_name);
+  switch (what) {
+    case PyTrace_CALL: {
+      PyCodeObject* f_code = frame->f_code;
+      thread_traces.active.emplace(now, 0, f_code, nullptr);
+      break;
     }
-
-    tracemes_[thread_id].push_back(
-        absl::make_unique<TraceMe>([&filename, line_no, &function] {
-          return absl::StrCat("$", io::Basename(filename), ":", line_no, " ",
-                              function);
-        }));
-  } else if (what == PyTrace_C_CALL && PyCFunction_Check(arg)) {
-    // Python stack does not have a filename/line_no for native calls.
-    auto* func = reinterpret_cast<PyCFunctionObject*>(arg);
-    PyObject* module = func->m_module;
-    string filename;
-    bool filename_ok;
-#if PY_MAJOR_VERSION < 3
-    filename_ok = (module != nullptr && PyString_Check(module));
-#else
-    filename_ok = (module != nullptr && PyUnicode_Check(module));
-#endif
-    if (filename_ok) {
-      filename = py::reinterpret_borrow<py::str>(module);
-    } else {
-      filename = "<unknown>";
+    case PyTrace_RETURN:
+    case PyTrace_EXCEPTION: {
+      if (!thread_traces.active.empty()) {
+        auto& entry = thread_traces.active.top();
+        entry.end_time_ns = now;
+        thread_traces.completed.emplace_back(std::move(entry));
+        thread_traces.active.pop();
+      } else if (options_.include_incomplete_events) {
+        PyCodeObject* f_code = frame->f_code;
+        thread_traces.completed.emplace_back(start_timestamp_ns_, now, f_code,
+                                             nullptr);
+      }
+      break;
     }
-
-    tracemes_[thread_id].push_back(
-        absl::make_unique<TraceMe>([&filename, func] {
-          return absl::StrCat(filename, " ", func->m_ml->ml_name);
-        }));
-  } else if (what == PyTrace_RETURN || what == PyTrace_C_RETURN ||
-             what == PyTrace_EXCEPTION || what == PyTrace_C_EXCEPTION) {
-    auto& thread_tracemes = tracemes_[thread_id];
-    if (!thread_tracemes.empty()) {
-      thread_tracemes.pop_back();
+    case PyTrace_C_CALL: {
+      if (PyCFunction_Check(arg)) {
+        // Python stack does not have a filename/line_no for native calls.
+        auto* func = reinterpret_cast<PyCFunctionObject*>(arg);
+        entries_[thread_id].active.emplace(now, 0, nullptr, func);
+      }
+      break;
     }
+    case PyTrace_C_RETURN:
+    case PyTrace_C_EXCEPTION: {
+      if (!thread_traces.active.empty()) {
+        auto& entry = thread_traces.active.top();
+        entry.end_time_ns = now;
+        thread_traces.completed.emplace_back(std::move(entry));
+        thread_traces.active.pop();
+      } else if (options_.include_incomplete_events) {
+        // Only the end of the event was observed; use profiler start as start.
+        if (PyCFunction_Check(arg)) {
+          // Python stack does not have a filename/line_no for native calls.
+          auto* func = reinterpret_cast<PyCFunctionObject*>(arg);
+          entries_[thread_id].completed.emplace_back(start_timestamp_ns_, now,
+                                                     nullptr, func);
+        }
+      }
+      break;
+    }
+    default:
+      break;
   }
 }
 
diff --git a/tensorflow/python/profiler/internal/python_hooks.h b/tensorflow/python/profiler/internal/python_hooks.h
index 582edf4..b30fcc3 100644
--- a/tensorflow/python/profiler/internal/python_hooks.h
+++ b/tensorflow/python/profiler/internal/python_hooks.h
@@ -16,14 +16,16 @@
 #define TENSORFLOW_PYTHON_PROFILER_INTERNAL_PYTHON_HOOKS_H_
 
 #include <memory>
+#include <deque>
+#include <stack>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
 #include "pybind11/cast.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/pytypes.h"
+#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 
 namespace tensorflow {
 namespace profiler {
@@ -33,6 +35,52 @@
 struct PythonHooksOptions {
   bool enable_trace_python_function = false;
   bool enable_python_traceme = true;
+  bool end_to_end_mode = false;
+  // Incomplete events are Python calls for which only the start or only the
+  // end was observed, but not both. When they are included in the final
+  // result, the profiler start and end times stand in for the missing
+  // timestamps.
+  bool include_incomplete_events = true;
+};
+
+struct PythonTraceEntry {
+  PythonTraceEntry(uint64 start, uint64 end, PyCodeObject* code,
+                   PyCFunctionObject* func)
+      : start_time_ns(start),
+        end_time_ns(end),
+        code_object(code),
+        function_object(func) {
+    Py_XINCREF(code_object);
+    Py_XINCREF(function_object);
+  }
+  ~PythonTraceEntry() {
+    Py_XDECREF(code_object);
+    Py_XDECREF(function_object);
+  }
+  PythonTraceEntry(PythonTraceEntry&& other) {
+    start_time_ns = other.start_time_ns;
+    end_time_ns = other.end_time_ns;
+    code_object = other.code_object;
+    function_object = other.function_object;
+    other.code_object = nullptr;
+    other.function_object = nullptr;
+  }
+
+  std::string Name() const;
+
+  uint64 start_time_ns;
+  uint64 end_time_ns;
+  PyCodeObject* code_object;
+  PyCFunctionObject* function_object;
+
+  PythonTraceEntry(const PythonTraceEntry& other) = delete;
+  void operator=(const PythonTraceEntry&) = delete;
+  void operator=(PythonTraceEntry&&) = delete;
+};
+
+struct PerThreadEvents {
+  std::deque<PythonTraceEntry> completed;
+  std::stack<PythonTraceEntry> active;
 };
 
 // Singleton for tracing python function calls.
@@ -41,19 +89,27 @@
   static PythonHooks* GetSingleton();
 
   void Start(const PythonHooksOptions& option);
-  void Stop(const PythonHooksOptions& option);
-  void Finalize();
+  void Stop();
+  void Finalize(XSpace* space);
   void ProfileSlow(const py::object& frame, const string& event,
                    const py::object& arg);
   void ProfileFast(PyFrameObject* frame, int what, PyObject* arg);
 
  private:
   void EnableTraceMe(bool enable);
+  void CollectData(XPlane* raw_plane);
 
   void SetProfilerInAllThreads();
   void ClearProfilerInAllThreads();
 
-  absl::flat_hash_map<int64, std::vector<std::unique_ptr<TraceMe>>> tracemes_;
+  // entries_ is only accessed while the GIL is held, so there are no races.
+  absl::flat_hash_map<int64, PerThreadEvents> entries_;
+  uint64 start_timestamp_ns_;
+  bool active_session_ = false;
+  PythonHooksOptions options_;
+  // In end-to-end mode, Python is already finalized before Stop()/Finalize()
+  // run, so we need to buffer the result.
+  absl::optional<XPlane> end_to_end_xplane_;
 };
 
 }  // namespace profiler
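
From Python, the tracer implemented by these hooks is driven through the profiler API; a minimal sketch assuming the tf.profiler.experimental options of this era, where python_tracer_level=1 enables the function tracing above:

    import tensorflow as tf

    options = tf.profiler.experimental.ProfilerOptions(python_tracer_level=1)
    tf.profiler.experimental.start("/tmp/logdir", options=options)
    # ... run the code to be profiled; Python calls are recorded as
    # XPlane events, one line per thread ...
    tf.profiler.experimental.stop()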
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index 54124df..e802e54 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -1800,6 +1800,8 @@
     root = tracking.AutoTrackable()
     root.table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
     root.table.insert("foo", 15)
+    root.table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
+    root.table2.insert("idk", 21)
 
     @def_function.function(
         input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 33780c1..361883a 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -757,6 +757,11 @@
     proto.variable.synchronization = obj.synchronization.value
     proto.variable.aggregation = obj.aggregation.value
     proto.variable.shape.CopyFrom(obj.shape.as_proto())
+    options = save_context.get_save_options()
+    if options.experimental_variable_policy._save_variable_devices(  # pylint: disable=protected-access
+    ):
+      if hasattr(obj, "device"):
+        proto.variable.device = obj.device
   elif isinstance(obj, def_function.Function):
     proto.function.CopyFrom(function_serialization.serialize_function(
         obj, function_name_map))
@@ -1005,8 +1010,8 @@
   utils_impl.get_or_create_variables_dir(export_dir)
   ckpt_options = checkpoint_options.CheckpointOptions(
       experimental_io_device=options.experimental_io_device)
-  object_saver.save(utils_impl.get_variables_path(export_dir),
-                    options=ckpt_options)
+  object_saver.save(
+      utils_impl.get_variables_path(export_dir), options=ckpt_options)
   builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                               export_dir)
   # Note that this needs to be the last file operation when saving the
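
The device-saving branch above is opt-in; a sketch of enabling it, mirroring the save_options module used by the tests below:

    from tensorflow.python.saved_model import save_options

    options = save_options.SaveOptions(
        experimental_variable_policy=save_options.VariablePolicy
        .SAVE_VARIABLE_DEVICES)
    # Pass `options` to tf.saved_model.save(...) to record each variable's
    # device in the SavedObjectGraph.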
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index c59b131..d74d190 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -514,12 +514,14 @@
     else:
       save.save(obj=root, export_dir=file_name, options=options)
 
-    graph_def = None
+    meta = None
     if meta_graph_only:
-      graph_def = meta_graph.read_meta_graph_file(file_name).graph_def
+      meta = meta_graph.read_meta_graph_file(file_name)
     else:
-      graph_def = loader_impl.parse_saved_model(
-          file_name).meta_graphs[0].graph_def
+      meta = loader_impl.parse_saved_model(file_name).meta_graphs[0]
+
+    # Check devices in meta graph nodes.
+    graph_def = meta.graph_def
     v0 = next((n for n in graph_def.node if n.name == "v0"), None)
     v1 = next((n for n in graph_def.node if n.name == "v1"), None)
     self.assertIsNotNone(v0)
@@ -531,6 +533,23 @@
       self.assertEmpty(v0.device)
       self.assertEmpty(v1.device)
 
+    # Check devices in object graph nodes.
+    object_graph_def = meta.object_graph_def
+    v0 = next((n.variable
+               for n in object_graph_def.nodes
+               if n.HasField("variable") and n.variable.name == "v0"), None)
+    v1 = next((n.variable
+               for n in object_graph_def.nodes
+               if n.HasField("variable") and n.variable.name == "v1"), None)
+    self.assertIsNotNone(v0)
+    self.assertIsNotNone(v1)
+    if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
+      self.assertIn("CPU:0", v0.device)
+      self.assertIn("CPU:1", v1.device)
+    else:
+      self.assertEmpty(v0.device)
+      self.assertEmpty(v1.device)
+
   @parameterized.named_parameters(
       ("_ExpandDistributedVariables",
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES),
diff --git a/tensorflow/python/tf_program/pywrap_tfd.py b/tensorflow/python/tf_program/pywrap_tfd.py
index 0d9a236..a7a30b7 100644
--- a/tensorflow/python/tf_program/pywrap_tfd.py
+++ b/tensorflow/python/tf_program/pywrap_tfd.py
@@ -137,8 +137,8 @@
   """Python wrap for a Tensorflow Program (essentially an mlir Module)."""
 
   def __init__(self):
-    mlir.registerDialects()
     self.ctx = mlir.MLIRContext()
+    mlir.preloadTensorFlowDialects(self.ctx)
     self.builder = mlir.Builder(self.ctx)
     self.module = mlir.ModuleOp.create(mlir.UnknownLoc.get(self.ctx))
     self.curr_func = None
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 0afd05e..8a843eb 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -375,11 +375,6 @@
         for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
           tensorflow::Device* device = devices[device_idx];
 
-          if (absl::StrContains(device->name(), "XLA") &&
-              !absl::StrContains(device_name, "XLA")) {
-            continue;
-          }
-
           if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
                   input_device_name, device->parsed_name())) {
             if (device->device_type() == tensorflow::DEVICE_CPU) {
@@ -387,13 +382,6 @@
                   "CPU does not support getting allocator information");
             }
 
-            if (absl::StrContains(device->device_type(), "XLA") &&
-                !absl::StrContains(device_name, "XLA")) {
-              // TODO(b/140134773): Remove this workaround.
-              // Do not accidentally match XLA devices.
-              continue;
-            }
-
             if (matched_device != nullptr) {
               tensorflow::ThrowValueError(
                   absl::StrFormat(
@@ -544,21 +532,13 @@
     return TFE_ContextGetDevicePlacementPolicy(
         tensorflow::InputTFE_Context(ctx));
   });
-  m.def("TFE_ContextGetMirroringPolicy", [](py::handle& ctx) {
-    return TFE_ContextGetMirroringPolicy(tensorflow::InputTFE_Context(ctx));
-  });
   m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
         [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
           TFE_ContextSetThreadLocalDevicePlacementPolicy(
               tensorflow::InputTFE_Context(ctx), policy);
         });
-  m.def("TFE_ContextSetThreadLocalMirroringPolicy",
-        [](py::handle& ctx, TFE_ContextMirroringPolicy policy) {
-          TFE_ContextSetThreadLocalMirroringPolicy(
-              tensorflow::InputTFE_Context(ctx), policy);
-        });
   m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
-                                      py::str proto) {
+                                      py::bytes proto) {
     tensorflow::Safe_TF_StatusPtr status =
         tensorflow::make_safe(TF_NewStatus());
     tensorflow::Safe_TF_BufferPtr buf =
@@ -568,7 +548,7 @@
     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
   });
   m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
-                                         py::str proto) {
+                                         py::bytes proto) {
     tensorflow::Safe_TF_StatusPtr status =
         tensorflow::make_safe(TF_NewStatus());
     tensorflow::Safe_TF_BufferPtr buf =
@@ -848,7 +828,7 @@
   m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
         py::return_value_policy::reference);
   m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options,
-                                          py::str proto) {
+                                          py::bytes proto) {
     tensorflow::Safe_TF_StatusPtr status =
         tensorflow::make_safe(TF_NewStatus());
     tensorflow::Safe_TF_BufferPtr buf =
@@ -862,8 +842,6 @@
   m.def("TFE_ContextOptionsSetLazyRemoteInputsCopy",
         &TFE_ContextOptionsSetLazyRemoteInputsCopy);
   m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
-  m.def("TFE_ContextOptionsSetMirroringPolicy",
-        &TFE_ContextOptionsSetMirroringPolicy);
   m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
   m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
         py::return_value_policy::reference);
@@ -899,7 +877,7 @@
           return tensorflow::PyoOrThrow(
               TFE_Py_EncodeArg(o.ptr(), include_tensor_ranks_only));
         });
-  m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::str proto) {
+  m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
     tensorflow::Safe_TF_StatusPtr status =
         tensorflow::make_safe(TF_NewStatus());
     tensorflow::Safe_TF_BufferPtr buf =
@@ -915,6 +893,14 @@
     TF_SetStatus(status.get(), static_cast<TF_Code>(code), message);
     TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
   });
+  m.def("TFE_CollectiveOpsCheckPeerHealth",
+        [](const py::handle& ctx, const char* task) {
+          tensorflow::Safe_TF_StatusPtr status =
+              tensorflow::make_safe(TF_NewStatus());
+          TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
+                                           task, status.get());
+          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+        });
   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
   m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
   m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
@@ -1312,9 +1298,4 @@
       .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
       .value("TF_ATTR_FUNC", TF_ATTR_FUNC)
       .export_values();
-
-  py::enum_<TFE_ContextMirroringPolicy>(m, "TFE_ContextMirroringPolicy")
-      .value("TFE_MIRRORING_NONE", TFE_MIRRORING_NONE)
-      .value("TFE_MIRRORING_ALL", TFE_MIRRORING_ALL)
-      .export_values();
 };
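
The py::str to py::bytes changes above matter because serialized protos are arbitrary binary data; accepting them as str would force a unicode round-trip. At the Python call sites this means passing SerializeToString() output directly, e.g. (a sketch, assuming the standard proto API):

    from tensorflow.core.protobuf import config_pb2

    config = config_pb2.ConfigProto()
    serialized = config.SerializeToString()  # bytes, not str
    # Wrappers such as TFE_ContextOptionsSetConfig now consume these bytes
    # verbatim, with no encoding step in between.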
diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py
index 7e0278a..421be25 100644
--- a/tensorflow/python/tpu/tpu_outside_compilation_test.py
+++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py
@@ -90,6 +90,10 @@
 
 class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
 
+  def setUp(self):
+    super(TpuOutsideCompilationTest, self).setUp()
+    config.set_soft_device_placement(False)
+
   def testResourceVariableAssignOnHost(self):
     strategy = get_tpu_strategy()
     with strategy.scope():
diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py
index c315d7c..d4ba15e 100644
--- a/tensorflow/python/tpu/tpu_strategy_util.py
+++ b/tensorflow/python/tpu/tpu_strategy_util.py
@@ -109,10 +109,6 @@
     context.context()._clear_caches()  # pylint: disable=protected-access
 
     serialized_topology = output.numpy()
-
-    # TODO(b/134094971): Remove this when lazy tensor copy in multi-device
-    # function has been implemented.
-    context.context().mirroring_policy = context.MIRRORING_ALL
   elif not ops.executing_eagerly_outside_functions():
     master = cluster_resolver.master()
     cluster_spec = cluster_resolver.cluster_spec()
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 332cc40..b6313f5 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -47,10 +47,12 @@
   def doTestBasic(self, use_resource=False, use_callable_params=False):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
       if use_resource:
-        var0 = resource_variable_ops.ResourceVariable(
-            [1.0, 2.0], dtype=dtype, name="var0_%d" % i)
-        var1 = resource_variable_ops.ResourceVariable(
-            [3.0, 4.0], dtype=dtype, name="var1_%d" % i)
+        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0],
+                                                      dtype=dtype,
+                                                      name="var0_%d" % i)
+        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0],
+                                                      dtype=dtype,
+                                                      name="var1_%d" % i)
       else:
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
@@ -63,8 +65,7 @@
         momentum = momentum()
       mom_opt = momentum_lib.MomentumOptimizer(
           learning_rate=learning_rate, momentum=momentum)
-      mom_update = mom_opt.apply_gradients(
-          zip([grads0, grads1], [var0, var1]))
+      mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
 
       if not context.executing_eagerly():
         self.evaluate(variables.global_variables_initializer())
@@ -87,14 +88,13 @@
       if not context.executing_eagerly():
         self.evaluate(mom_update)
       # Check that the momentum accumulators have been updated.
-      self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
-                                         self.evaluate(slot0))
-      self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
-                                         self.evaluate(slot1))
+      self.assertAllCloseAccordingToType(
+          np.array([0.1, 0.1]), self.evaluate(slot0))
+      self.assertAllCloseAccordingToType(
+          np.array([0.01, 0.01]), self.evaluate(slot1))
       # Check that the parameters have been updated.
       self.assertAllCloseAccordingToType(
-          np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
-          self.evaluate(var0))
+          np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), self.evaluate(var0))
       self.assertAllCloseAccordingToType(
           np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
           self.evaluate(var1))
@@ -118,8 +118,8 @@
           ]), self.evaluate(var0))
       self.assertAllCloseAccordingToType(
           np.array([
-              2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
-                  (0.9 * 0.01 + 0.01) * 2.0)
+              2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+              3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
           ]), self.evaluate(var1))
 
   def testBasic(self):
@@ -137,10 +137,12 @@
   def testVariablesAcrossGraphs(self):
     optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5)
     with ops.Graph().as_default():
-      var0 = resource_variable_ops.ResourceVariable(
-          [1.0, 2.0], dtype=dtypes.float32, name="var0")
-      var1 = resource_variable_ops.ResourceVariable(
-          [3.0, 4.0], dtype=dtypes.float32, name="var1")
+      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0],
+                                                    dtype=dtypes.float32,
+                                                    name="var0")
+      var1 = resource_variable_ops.ResourceVariable([3.0, 4.0],
+                                                    dtype=dtypes.float32,
+                                                    name="var1")
       loss = math_ops.reduce_sum(var0 + var1)
       optimizer.minimize(loss)
       optimizer_variables = optimizer.variables()
@@ -149,10 +151,12 @@
       self.assertEqual(2, len(optimizer_variables))
 
     with ops.Graph().as_default():
-      var2 = resource_variable_ops.ResourceVariable(
-          [1.0, 2.0], dtype=dtypes.float32, name="var2")
-      var3 = resource_variable_ops.ResourceVariable(
-          [3.0, 4.0], dtype=dtypes.float32, name="var3")
+      var2 = resource_variable_ops.ResourceVariable([1.0, 2.0],
+                                                    dtype=dtypes.float32,
+                                                    name="var2")
+      var3 = resource_variable_ops.ResourceVariable([3.0, 4.0],
+                                                    dtype=dtypes.float32,
+                                                    name="var3")
       loss = math_ops.reduce_sum(var2 + var3)
       optimizer.minimize(loss)
       optimizer_variables = optimizer.variables()
@@ -181,9 +185,8 @@
           opt_op.run()
           var0_np, accum0_np = self._update_nesterov_momentum_numpy(
               var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
-          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
-                                                                    accum1_np,
-                                                                    3, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(
+              var1_np, accum1_np, 3, 2.0, 0.9)
           self.assertAllClose(var0_np, self.evaluate(var0))
           self.assertAllClose(var1_np, self.evaluate(var1))
 
@@ -200,32 +203,29 @@
           grads.append(var0_np * 10)
           var0_np, accum0_np = self._update_nesterov_momentum_numpy(
               var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
-          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
-                                                                    accum1_np,
-                                                                    3, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(
+              var1_np, accum1_np, 3, 2.0, 0.9)
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
         var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
         accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
         accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
         var0 = variables.Variable(var0_np)
         var1 = variables.Variable(var1_np)
-        loss = 5 * var0 * var0 + 3 * var1
         mom_op = momentum_lib.MomentumOptimizer(
             learning_rate=2.0, momentum=0.9, use_nesterov=True)
         x_feed = array_ops.placeholder(dtype)
-        y_feed = ops.IndexedSlices(
-            x_feed, constant_op.constant([0, 1]), constant_op.constant([2]))
-        grads_and_vars = [(y_feed, var0), (constant_op.constant(
-            [3.0, 3.0], dtype=dtype), var1)]
+        y_feed = ops.IndexedSlices(x_feed, constant_op.constant([0, 1]),
+                                   constant_op.constant([2]))
+        grads_and_vars = [(y_feed, var0),
+                          (constant_op.constant([3.0, 3.0], dtype=dtype), var1)]
         opt_update = mom_op.apply_gradients(grads_and_vars)
         self.evaluate(variables.global_variables_initializer())
         for t in range(1, 5):
           opt_update.run(feed_dict={x_feed: grads[t - 1]})
           var0_np, accum0_np = self._update_nesterov_momentum_numpy(
               var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
-          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
-                                                                    accum1_np,
-                                                                    3, 2.0, 0.9)
+          var1_np, accum1_np = self._update_nesterov_momentum_numpy(
+              var1_np, accum1_np, 3, 2.0, 0.9)
           self.assertAllClose(var0_np, self.evaluate(var0))
           self.assertAllClose(var1_np, self.evaluate(var1))
 
@@ -249,6 +249,7 @@
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
         return pred * pred
+
       # pylint: enable=cell-var-from-loop
 
       opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
@@ -464,15 +465,11 @@
         var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
         var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
         grads0 = ops.IndexedSlices(
-            constant_op.constant(
-                [[.1, .1]], dtype=dtype),
-            constant_op.constant([1]),
-            constant_op.constant([4, 2]))
+            constant_op.constant([[.1, .1]], dtype=dtype),
+            constant_op.constant([1]), constant_op.constant([4, 2]))
         grads1 = ops.IndexedSlices(
-            constant_op.constant(
-                [[.01, .01], [.01, .01]], dtype=dtype),
-            constant_op.constant([2, 3]),
-            constant_op.constant([4, 2]))
+            constant_op.constant([[.01, .01], [.01, .01]], dtype=dtype),
+            constant_op.constant([2, 3]), constant_op.constant([4, 2]))
         mom_opt = momentum_lib.MomentumOptimizer(
             learning_rate=2.0, momentum=0.9)
         mom_update = mom_opt.apply_gradients(
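
Note: the `_update_nesterov_momentum_numpy` helper these tests call is not
shown in this diff; the sketch below is an assumed, algebraically equivalent
form of the standard Nesterov momentum update it checks against.

    # Hedged sketch; the exact helper body is not included in this patch.
    def update_nesterov_momentum_numpy(var, accum, grad, lr, momentum):
      accum = momentum * accum + grad             # accumulate velocity
      var = var - lr * (grad + momentum * accum)  # Nesterov look-ahead step
      return var, accum
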
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index c3c3570..d4af3fb 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -326,7 +326,7 @@
   Raises:
     ValueError: If the saveable has already been processed.
   """
-  if saveable.op in seen_ops:
+  if saveable.op is not None and saveable.op in seen_ops:
     raise ValueError("The same saveable will be restored with two names: %s" %
                      saveable.name)
   saveables.append(saveable)
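
Why the guard matters: `None == None`, so two unrelated saveables that both
lack an op would previously collide in `seen_ops` and raise a spurious
duplicate-name error. A minimal sketch of the old failure mode:

    seen_ops = [None]  # a previously processed saveable had op=None
    op = None          # a second, unrelated saveable also has op=None
    # Old check: `op in seen_ops` was True, firing a bogus ValueError.
    # The patched check short-circuits on None and accepts the saveable.
    if op is not None and op in seen_ops:
      raise ValueError("duplicate saveable")
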
diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD
index 5f3f4fd..d48f066 100644
--- a/tensorflow/python/types/BUILD
+++ b/tensorflow/python/types/BUILD
@@ -35,6 +35,7 @@
         ":doc_typealias",
         "//tensorflow/python:tf_export",
         "//third_party/py/numpy",
+        "@typing_extensions_archive//:typing_extensions",
     ],
 )
 
diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py
index bec5aec..b450659 100644
--- a/tensorflow/python/types/core.py
+++ b/tensorflow/python/types/core.py
@@ -18,14 +18,21 @@
 from __future__ import division
 from __future__ import print_function
 
+import sys
 import textwrap
 
 from typing import Union
+
 import numpy as np
 
 from tensorflow.python.types import doc_typealias
 from tensorflow.python.util.tf_export import tf_export
 
+if sys.version_info >= (3, 8):
+  from typing import Protocol  # pylint:disable=g-import-not-at-top
+else:
+  from typing_extensions import Protocol  # pylint:disable=g-import-not-at-top
+
 # TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced.
 # TODO(mdan): Add type annotations.
 
@@ -67,9 +74,25 @@
     pass
 
 
+class TensorProtocol(Protocol):
+  """Protocol type for objects that can be converted to Tensor."""
+
+  def __tf_tensor__(self, dtype=None, name=None):
+    """Converts this object to a Tensor.
+
+    Args:
+      dtype: data type for the returned Tensor
+      name: a name for the operations which create the Tensor
+
+    Returns:
+      A Tensor.
+    """
+    pass
+
+
 # TODO(rahulkamat): Add missing types that are convertible to Tensor.
-TensorLike = Union[Tensor, int, float, bool, str, complex, tuple, list,
-                   np.ndarray]
+TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, complex,
+                   tuple, list, np.ndarray]
 doc_typealias.document(
     obj=TensorLike,
     doc=textwrap.dedent("""\
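
Illustration of the new protocol (hedged: the `Density` class is invented
here, and it assumes the conversion machinery dispatches to `__tf_tensor__`
for objects implementing TensorProtocol):

    import tensorflow as tf

    class Density:
      """Hypothetical wrapper that opts into Tensor conversion."""

      def __init__(self, values):
        self._values = values

      def __tf_tensor__(self, dtype=None, name=None):
        # Invoked by converters that honor TensorProtocol.
        return tf.constant(self._values, dtype=dtype, name=name)

    # Instances now also satisfy the widened TensorLike alias above.
    t = tf.convert_to_tensor(Density([0.2, 0.8]))
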
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 66f43a3..9f4ae1d 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -335,6 +335,9 @@
   Raises:
     TypeError: The nest is or contains a dict with non-sortable keys.
   """
+  if structure is None:
+    return [None]
+  expand_composites = bool(expand_composites)
   return _pywrap_utils.Flatten(structure, expand_composites)
 
 
@@ -392,6 +395,10 @@
     TypeError: If the two structures differ in the type of sequence in any of
       their substructures. Only possible if `check_types` is `True`.
   """
+  # Convert to bool explicitly, as otherwise pybind cannot render the
+  # type mismatch message correctly. See GitHub issue 42329 for details.
+  check_types = bool(check_types)
+  expand_composites = bool(expand_composites)
   try:
     _pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
                                       expand_composites)
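
Taken together, the two nest.py changes make `flatten(None)` well defined
and coerce flag arguments to real booleans before they cross into pybind;
a hedged sketch of the resulting behavior:

    from tensorflow.python.util import nest

    assert nest.flatten(None) == [None]           # short-circuits pybind
    assert nest.flatten({"b": (2, 3), "a": 1}) == [1, 2, 3]  # sorted keys
    # A non-bool flag (e.g. a tensor passed as check_types) now surfaces a
    # clean ValueError/TypeError rather than a garbled pybind message.
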
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index fb3f210..7f8bb24 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -1218,6 +1218,18 @@
         expected,
     )
 
+  def testInvalidCheckTypes(self):
+    with self.assertRaises((ValueError, TypeError)):
+      nest.assert_same_structure(
+          nest1=array_ops.zeros((1)),
+          nest2=array_ops.ones((1, 1, 1)),
+          check_types=array_ops.ones((2)))
+    with self.assertRaises((ValueError, TypeError)):
+      nest.assert_same_structure(
+          nest1=array_ops.zeros((1)),
+          nest2=array_ops.ones((1, 1, 1)),
+          expand_composites=array_ops.ones((2)))
+
 
 class NestBenchmark(test.Benchmark):
 
diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc
index aa9be63..7f5ff7f 100644
--- a/tensorflow/python/util/tf_stack.cc
+++ b/tensorflow/python/util/tf_stack.cc
@@ -127,6 +127,11 @@
       // For compatibility with the traceback module.
       .def("__eq__", &FrameSummary::operator==)
       .def("__ne__", &FrameSummary::operator!=)
+      .def("__hash__",
+           [](const FrameSummary& self) {
+             return py::hash(
+                 py::make_tuple(self.filename, self.lineno, self.name));
+           })
       .def("__getitem__",
            [](const FrameSummary& self, const py::object& index) -> py::object {
              return py::make_tuple(self.filename, self.lineno, self.name,
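
With `__eq__`/`__ne__` already bound, adding a `__hash__` keyed on the same
`(filename, lineno, name)` tuple keeps hashing consistent with equality, so
FrameSummary objects can serve as set members or dict keys. A hypothetical
usage sketch:

    def dedupe_frames(stack):
      """Drops repeated frames; relies on FrameSummary being hashable."""
      seen = set()
      unique = []
      for frame in stack:
        if frame not in seen:
          seen.add(frame)
          unique.append(frame)
      return unique
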
diff --git a/tensorflow/python/util/tf_stack_test.py b/tensorflow/python/util/tf_stack_test.py
index dc5a2a2..07dc2d3 100644
--- a/tensorflow/python/util/tf_stack_test.py
+++ b/tensorflow/python/util/tf_stack_test.py
@@ -52,6 +52,17 @@
     another_frame0, _ = tf_stack.extract_stack(limit=2)
     self.assertEqual(frame0, another_frame0)
 
+  def testFrameSummaryEqualityAndHash(self):
+    # Both defined on the same line to produce identical stacks.
+    frame1, frame2 = tf_stack.extract_stack(), tf_stack.extract_stack()
+    self.assertEqual(len(frame1), len(frame2))
+    for f1, f2 in zip(frame1, frame2):
+      self.assertEqual(f1, f2)
+      self.assertEqual(hash(f1), hash(f1))  # Hash is stable across calls.
+      self.assertEqual(hash(f1), hash(f2))
+    self.assertEqual(frame1, frame2)
+    self.assertEqual(hash(tuple(frame1)), hash(tuple(frame2)))
+
 
 def extract_stack(limit=None):
   # Both defined on the same line to produce identical stacks.
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 41b02a3..a341417 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -24,7 +24,7 @@
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/python/lib/core/safe_ptr.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
 
 namespace tensorflow {
 namespace swig {
diff --git a/tensorflow/stream_executor/tpu/tpu_executable_interface.cc b/tensorflow/stream_executor/tpu/tpu_executable_interface.cc
index f260cc1..90ea2dc 100644
--- a/tensorflow/stream_executor/tpu/tpu_executable_interface.cc
+++ b/tensorflow/stream_executor/tpu/tpu_executable_interface.cc
@@ -194,8 +194,9 @@
   // Address of the buffer in TPU memory that is being speculated.
   absl::optional<se::DeviceMemoryBase> cross_program_prefetch_addr;
   if (hlo_module_) {
-    for (const auto& [parameter, index] :
-         hlo_module_->CrossProgramPrefetches()) {
+    for (const auto& prefetch : hlo_module_->CrossProgramPrefetches()) {
+      const auto& parameter = prefetch.first;
+      const auto& index = prefetch.second;
       CHECK_LT(parameter, arguments.size());
       // Ensure the cross program prefetched buffer doesn't alias with any
       // program outputs. If the input and output aliased, the buffer could be
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
index c35745e..9b8b9cd 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc
@@ -26,7 +26,7 @@
 
 namespace {
 TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform,
-                                                  int tries_left = 3) {
+                                                  int tries_left = 5) {
   if (tries_left <= 0) {
     LOG(ERROR) << "Unable to find a TPU platform after exhausting all tries. "
                   "Returning nullptr...";
@@ -60,7 +60,7 @@
   if (!status_or_other_tpu_platforms.ok() &&
       status_or_other_tpu_platforms.status().code() != error::NOT_FOUND) {
     LOG(WARNING) << "Error when getting other TPU platforms: "
-                 << status_or_tpu_platform.status();
+                 << status_or_other_tpu_platforms.status();
     return nullptr;
   }
 
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 36ab1f1..b045688 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -352,12 +352,6 @@
 def tf_openmp_copts():
     return (if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fno-openmp"]))
 
-def tfe_xla_copts():
-    return select({
-        "//tensorflow:with_xla_support": ["-DTENSORFLOW_EAGER_USE_XLA"],
-        "//conditions:default": [],
-    })
-
 def tf_opts_nortti():
     return [
         "-fno-rtti",
@@ -2577,12 +2571,13 @@
         **kwargs
     )
 
-def tf_version_info_genrule(name, out):
+def tf_version_info_genrule(name, out, compatible_with = None):
     # TODO(gunan): Investigate making this action hermetic so we do not need
     # to run it locally.
     _local_genrule(
         name = name,
         out = out,
+        compatible_with = compatible_with,
         exec_tool = "//tensorflow/tools/git:gen_git_source",
         srcs = [
             "@local_config_git//:gen/spec.json",
@@ -2928,3 +2923,6 @@
 
 def tf_grpc_cc_dependency():
     return "//tensorflow:grpc++"
+
+def get_compatible_with_portable():
+    return []
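
A hypothetical BUILD-file sketch of how the new `compatible_with` parameter
and `get_compatible_with_portable()` are meant to compose (the target name
is invented):

    load("//tensorflow:tensorflow.bzl",
         "get_compatible_with_portable", "tf_version_info_genrule")

    tf_version_info_genrule(
        name = "version_info_gen",
        out = "version_info.cc",
        compatible_with = get_compatible_with_portable(),
    )
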
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
index 7397719..7a25de0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt
@@ -25,6 +25,10 @@
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "enable_tensor_float_32_execution"
+    argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "get_device_details"
     argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
   }
@@ -76,4 +80,8 @@
     name: "set_visible_devices"
     argspec: "args=[\'devices\', \'device_type\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "tensor_float_32_execution_enabled"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
 }
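
The two new golden entries correspond to the public TensorFloat-32 toggles;
a short usage sketch (TF32 only affects Ampere-class GPUs):

    import tensorflow as tf

    # Disable TF32 to get full float32 precision at some speed cost.
    tf.config.experimental.enable_tensor_float_32_execution(False)
    assert not tf.config.experimental.tensor_float_32_execution_enabled()
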
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index a9f6f06..da08722 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -174,7 +174,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index 168539b..1719c8b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -180,7 +180,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
index 2aff054..d93c018 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -175,7 +175,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
index ed49246..9fba915 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
@@ -175,7 +175,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
index 58f8cf2..e3c8c7e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
@@ -14,6 +14,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
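
The `global_clipnorm` property recurs in every optimizer golden file below;
a hedged usage sketch, assuming the constructor accepts it alongside
`clipnorm`/`clipvalue`:

    import tensorflow as tf

    # Clip by the global norm across all gradients, not per variable.
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, global_clipnorm=1.0)
    print(opt.global_clipnorm)  # -> 1.0
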
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index 4368742..15c0ab5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -174,7 +174,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index 8e9409f..729fdd6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -180,7 +180,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt
index fb341cb..88e4ecf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt
index d8039ed..89e0718 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt
index 912f92f..29b1fba 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt
index 3abc6d3..c481aa0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt
index 6257c71..a2b9d31 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-ftrl.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt
index 2ce311d..650ac77 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt
index 4855395..50e3da3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt
@@ -12,6 +12,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
index 80a1449..ab8391e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt
index 8acfe21..2bad07d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
index 7397719..7a25de0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt
@@ -25,6 +25,10 @@
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "enable_tensor_float_32_execution"
+    argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "get_device_details"
     argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None"
   }
@@ -76,4 +80,8 @@
     name: "set_visible_devices"
     argspec: "args=[\'devices\', \'device_type\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "tensor_float_32_execution_enabled"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index a9f6f06..da08722 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -174,7 +174,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index 168539b..1719c8b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -180,7 +180,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
index 2aff054..d93c018 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -175,7 +175,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
index ed49246..9fba915 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
@@ -175,7 +175,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
index 58f8cf2..e3c8c7e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt
@@ -14,6 +14,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 4368742..15c0ab5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -174,7 +174,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index 8e9409f..729fdd6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -180,7 +180,7 @@
   }
   member_method {
     name: "compile"
-    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'weighted_metrics\', \'run_eagerly\', \'steps_per_execution\'], varargs=None, keywords=kwargs, defaults=[\'rmsprop\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt
index fb341cb..88e4ecf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt
index d8039ed..89e0718 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt
index 912f92f..29b1fba 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt
index 3abc6d3..c481aa0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt
index 6257c71..a2b9d31 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-ftrl.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt
index 2ce311d..650ac77 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt
index 4855395..50e3da3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt
@@ -12,6 +12,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
index 80a1449..ab8391e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt
index 8acfe21..2bad07d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt
index 06212bd..605cb27 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adadelta.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt
index 09fff05..a436583 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adagrad.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt
index 195ba9e..f874658 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt
index 9859da4..6798187 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adamax.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt
index b338388..15efc6a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-ftrl.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt
index 128f223..00cf3e0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-nadam.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt
index 5633beb..881d15c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-optimizer.pbtxt
@@ -12,6 +12,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt
index db89ecb..661e9cb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-r-m-sprop.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt
index 0cb0205..a14c9a4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.optimizers.-s-g-d.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'property\'>"
   }
   member {
+    name: "global_clipnorm"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "iterations"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/ci_build/builds/docker_test.sh b/tensorflow/tools/ci_build/builds/docker_test.sh
index b2d1dba..eee0a91 100755
--- a/tensorflow/tools/ci_build/builds/docker_test.sh
+++ b/tensorflow/tools/ci_build/builds/docker_test.sh
@@ -122,7 +122,7 @@
 "${DOCKER_IMG_TAG}" \
 /bin/bash -c "tensorflow/tools/ci_build/builds/run_pip_tests.sh && "\
 "tensorflow/tools/ci_build/builds/test_tutorials.sh && "\
-"tensorflow/tools/ci_bukld/builds/integration_tests.sh"
+"tensorflow/tools/ci_build/builds/integration_tests.sh"
 
 RESULT=$?
 
diff --git a/tensorflow/tools/ci_build/builds/libtensorflow.sh b/tensorflow/tools/ci_build/builds/libtensorflow.sh
index a281afe..1ddc57d 100755
--- a/tensorflow/tools/ci_build/builds/libtensorflow.sh
+++ b/tensorflow/tools/ci_build/builds/libtensorflow.sh
@@ -54,7 +54,7 @@
   BAZEL_OPTS="--config=opt --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0"
   export CC_OPT_FLAGS="-mavx -msse4.2"
   if [ "${TF_NEED_CUDA}" == "1" ]; then
-    BAZEL_OPTS="${BAZEL_OPTS} --config=cuda --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
+    BAZEL_OPTS="${BAZEL_OPTS} --config=cuda --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain"
     export TF_NEED_ROCM=0
   fi
   bazel clean --expunge
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index f4961e8..c3daaba 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -109,6 +109,7 @@
 "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
 "^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
 "^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
+"^tensorflow/python/keras/engine/base_layer.py.*\[E1102.*not-callable "\
 "^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
 "^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
 "^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\
diff --git a/tensorflow/tools/ci_build/horovod/gpu/nightly.sh b/tensorflow/tools/ci_build/horovod/gpu/nightly.sh
index 0601936..87e5f80 100644
--- a/tensorflow/tools/ci_build/horovod/gpu/nightly.sh
+++ b/tensorflow/tools/ci_build/horovod/gpu/nightly.sh
@@ -63,7 +63,11 @@
 
 # Install Horovod.
 cd ..
-HOROVOD_WITH_TENSORFLOW=1
+# Export these so horovod's setup.py (run by pip) actually sees the flags.
+export HOROVOD_GPU_OPERATIONS=NCCL
+export HOROVOD_WITH_TENSORFLOW=1
+export HOROVOD_WITHOUT_PYTORCH=1
+export HOROVOD_WITHOUT_MXNET=1
 pip3.7 install horovod[tensorflow] --user
 
 # Install tests.
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_libtensorflow.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_libtensorflow.sh
deleted file mode 100644
index a0e3a7f..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_libtensorflow.sh
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/bin/bash
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-
-# Source the external common scripts.
-source tensorflow/tools/ci_build/release/common.sh
-
-
-# Install latest bazel
-install_bazelisk
-which bazel
-
-# Install realpath
-sudo apt-get install realpath
-
-# Update the version string to nightly
-if [ -n "${IS_NIGHTLY_BUILD}" ]; then
-  ./tensorflow/tools/ci_build/update_version.py --nightly
-fi
-
-./tensorflow/tools/ci_build/linux/libtensorflow.sh
-
-# Copy the nightly version update script
-if [ -n "${IS_NIGHTLY_BUILD}" ]; then
-  cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package
-fi
-
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py35_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py35_nonpip.sh
deleted file mode 100644
index fee64f0..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py35_nonpip.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.5
-# Update bazel
-install_bazelisk
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=0
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.5)
-export TF2_BEHAVIOR=1
-yes "" | "$PYTHON_BIN_PATH" configure.py
-tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py35,-v1only"
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Run tests
-set +e
-bazel test --test_output=errors --config=opt --test_lang_filters=py \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \
-  --linkopt=-lrt \
-  --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
-  --build_tag_filters="${tag_filters}" \
-  --test_tag_filters="${tag_filters}" -- \
-  ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
-test_xml_summary_exit
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py35_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py35_pip.sh
deleted file mode 100644
index bdbb7f1..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py35_pip.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.5
-# Update bazel
-install_bazelisk
-
-# Export required variables for running pip.sh
-export OS_TYPE="UBUNTU"
-export CONTAINER_TYPE="CPU"
-export TF_PYTHON_VERSION='python3.5'
-
-# Run configure.
-export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION})
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Export optional variables for running pip.sh
-export TF_BUILD_FLAGS="--config=release_cpu_linux"
-export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1"
-export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... "
-export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean"
-export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py35,-v1only'
-#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo.
-export TF_PROJECT_NAME="tensorflow_cpu"
-export TF_PIP_TEST_ROOT="pip_test"
-
-./tensorflow/tools/ci_build/builds/pip_new.sh
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py36_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py36_nonpip.sh
deleted file mode 100644
index 6b05141..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py36_nonpip.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.6
-# Update bazel
-install_bazelisk
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=0
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.6)
-export TF2_BEHAVIOR=1
-yes "" | "$PYTHON_BIN_PATH" configure.py
-tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py36,-v1only"
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Run tests
-set +e
-bazel test --test_output=errors --config=opt --test_lang_filters=py \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \
-  --linkopt=-lrt \
-  --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
-  --build_tag_filters="${tag_filters}" \
-  --test_tag_filters="${tag_filters}" -- \
-  ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
-test_xml_summary_exit
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py36_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py36_pip.sh
deleted file mode 100644
index 6277291..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py36_pip.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.6
-# Update bazel
-install_bazelisk
-
-# Export required variables for running pip.sh
-export OS_TYPE="UBUNTU"
-export CONTAINER_TYPE="CPU"
-export TF_PYTHON_VERSION='python3.6'
-
-# Run configure.
-export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION})
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Export optional variables for running pip.sh
-export TF_BUILD_FLAGS="--config=release_cpu_linux"
-export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1"
-export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... "
-export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean"
-export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py36,-v1only'
-#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo.
-export TF_PROJECT_NAME="tensorflow_cpu"
-export TF_PIP_TEST_ROOT="pip_test"
-
-./tensorflow/tools/ci_build/builds/pip_new.sh
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py37_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py37_nonpip.sh
deleted file mode 100644
index db0c605..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py37_nonpip.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.7
-# Update bazel
-install_bazelisk
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=0
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.7)
-export TF2_BEHAVIOR=1
-yes "" | "$PYTHON_BIN_PATH" configure.py
-tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py37,-v1only"
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Run tests
-set +e
-bazel test --test_output=errors --config=opt --test_lang_filters=py \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \
-  --linkopt=-lrt \
-  --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
-  --build_tag_filters="${tag_filters}" \
-  --test_tag_filters="${tag_filters}" -- \
-  ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
-test_xml_summary_exit
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py37_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py37_pip.sh
deleted file mode 100644
index ff88ae4..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py37_pip.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.7
-# Update bazel
-install_bazelisk
-
-# Export required variables for running pip.sh
-export OS_TYPE="UBUNTU"
-export CONTAINER_TYPE="CPU"
-export TF_PYTHON_VERSION='python3.7'
-
-# Run configure.
-export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION})
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Export optional variables for running pip.sh
-export TF_BUILD_FLAGS="--config=release_cpu_linux"
-export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1"
-export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... "
-export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean"
-export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py37,-v1only'
-#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo.
-export TF_PROJECT_NAME="tensorflow_cpu"
-export TF_PIP_TEST_ROOT="pip_test"
-
-./tensorflow/tools/ci_build/builds/pip_new.sh
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py38_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py38_nonpip.sh
deleted file mode 100644
index 36da301..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py38_nonpip.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.8
-# Update bazel
-install_bazelisk
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=0
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.8)
-export TF2_BEHAVIOR=1
-yes "" | "$PYTHON_BIN_PATH" configure.py
-tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py38,-v1only"
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Run tests
-set +e
-bazel test --test_output=errors --config=opt --test_lang_filters=py \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \
-  --linkopt=-lrt \
-  --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
-  --build_tag_filters="${tag_filters}" \
-  --test_tag_filters="${tag_filters}" -- \
-  ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
-test_xml_summary_exit
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py38_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py38_pip.sh
deleted file mode 100644
index 52872cf..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/cpu_py38_pip.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.8
-# Update bazel
-install_bazelisk
-
-# Export required variables for running pip.sh
-export OS_TYPE="UBUNTU"
-export CONTAINER_TYPE="CPU"
-export TF_PYTHON_VERSION='python3.8'
-
-# Run configure.
-export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION})
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Export optional variables for running pip.sh
-export TF_BUILD_FLAGS="--config=release_cpu_linux"
-export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py --test_output=errors --verbose_failures=true --keep_going --test_env=TF2_BEHAVIOR=1"
-export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... "
-export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean"
-export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py38,-v1only'
-#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo.
-export TF_PROJECT_NAME="tensorflow_cpu"
-export TF_PIP_TEST_ROOT="pip_test"
-
-./tensorflow/tools/ci_build/builds/pip_new.sh
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_libtensorflow.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_libtensorflow.sh
deleted file mode 100644
index d294311..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_libtensorflow.sh
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-
-# Source the external common scripts.
-source tensorflow/tools/ci_build/release/common.sh
-
-
-# Install latest bazel
-install_bazelisk
-which bazel
-
-# Install realpath
-sudo apt-get install realpath
-
-export TF_NEED_CUDA=1
-
-# Update the version string to nightly
-if [ -n "${IS_NIGHTLY_BUILD}" ]; then
-  ./tensorflow/tools/ci_build/update_version.py --nightly
-fi
-
-./tensorflow/tools/ci_build/linux/libtensorflow.sh
-
-# Copy the nightly version update script
-if [ -n "${IS_NIGHTLY_BUILD}" ]; then
-  cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package
-fi
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_pip_on_cpu.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_pip_on_cpu.sh
deleted file mode 100755
index 6e67bf2..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_pip_on_cpu.sh
+++ /dev/null
@@ -1,61 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.6
-# Update Bazel to the desired version
-install_bazelisk
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=1
-export TF_CUDA_VERSION=10
-export TF_CUDNN_VERSION=7
-export TF_NEED_TENSORRT=1
-export TENSORRT_INSTALL_PATH=/usr/local/tensorrt
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.6)
-export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib"
-export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_37,sm_52,sm_60,sm_61,compute_70
-
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-########################
-## Build GPU pip package
-########################
-bazel build --config=opt \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \
-  tensorflow/tools/pip_package:build_pip_package
-
-# Set TF nightly flag so we get the proper version of estimator
-if [[ "$IS_NIGHTLY" == 1 ]]; then
-  NIGHTLY_FLAG="--nightly_flag"
-fi
-
-PIP_WHL_DIR=whl
-mkdir -p ${PIP_WHL_DIR}
-PIP_WHL_DIR=$(readlink -f ${PIP_WHL_DIR})  # Get absolute path
-bazel-bin/tensorflow/tools/pip_package/build_pip_package "${PIP_WHL_DIR}" "${NIGHTLY_FLAG}"
-WHL_PATH=$(ls "${PIP_WHL_DIR}"/*.whl)
-
-cp "${WHL_PATH}" "$(pwd)"/.
-chmod +x tensorflow/tools/ci_build/builds/docker_cpu_pip.sh
-docker run -e "BAZEL_VERSION=${BAZEL_VERSION}" -e "CI_BUILD_USER=$(id -u -n)" -e "CI_BUILD_UID=$(id -u)"  -e "CI_BUILD_GROUP=$(id -g -n)" -e "CI_BUILD_GID=$(id -g)"  -e "CI_BUILD_HOME=/bazel_pip" -v "$(pwd)":/bazel_pip tensorflow/tensorflow:devel "./bazel_pip/tensorflow/tools/ci_build/builds/with_the_same_user" "./bazel_pip/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh"
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py35_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py35_nonpip.sh
deleted file mode 100644
index 3e91bf7..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py35_nonpip.sh
+++ /dev/null
@@ -1,61 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.5
-# Update bazel
-install_bazelisk
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=1
-export TF_CUDA_VERSION=11
-export TF_CUDNN_VERSION=8
-export TF_NEED_TENSORRT=1
-export TENSORRT_INSTALL_PATH=/usr/local/tensorrt
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.5)
-export TF2_BEHAVIOR=1
-export PROJECT_NAME="tensorflow_gpu"
-export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib"
-export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_37,sm_52,sm_60,sm_61,compute_70
-
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py35"
-
-set +e
-ls /usr/include/cud*
-bazel test --config=cuda --config=opt -s \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain \
-  --linkopt=-lrt \
-  --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
-  --test_lang_filters=py \
-  --test_tag_filters=${tag_filters} \
-  --build_tag_filters=${tag_filters} \
-  --test_timeout="300,450,1200,3600" --local_test_jobs=4 \
-  --test_output=errors --verbose_failures=true \
-  --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
-  -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
-test_xml_summary_exit
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py35_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py35_pip.sh
deleted file mode 100644
index 2a5c550..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py35_pip.sh
+++ /dev/null
@@ -1,55 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.5
-# Update bazel
-install_bazelisk
-
-# Export required variables for running pip.sh
-export OS_TYPE="UBUNTU"
-export CONTAINER_TYPE="GPU"
-export TF_PYTHON_VERSION='python3.5'
-
-# Run configure.
-export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION})
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Export optional variables for running pip.sh
-export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py35'
-export TF_BUILD_FLAGS="--config=release_gpu_linux "
-export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \
---distinct_host_configuration=false \
---action_env=TF_CUDA_VERSION=10 --action_env=TF_CUDNN_VERSION=7 --test_env=TF2_BEHAVIOR=1 \
---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \
---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \
---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute "
-export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... "
-export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean"
-#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo.
-export TF_PROJECT_NAME="tensorflow_gpu"
-export TF_PIP_TEST_ROOT="pip_test"
-
-# To build both tensorflow and tensorflow-gpu pip packages
-export TF_BUILD_BOTH_GPU_PACKAGES=1
-
-./tensorflow/tools/ci_build/builds/pip_new.sh
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py36_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py36_nonpip.sh
deleted file mode 100644
index 70038a8..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py36_nonpip.sh
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.6
-# Update bazel
-install_bazelisk
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=1
-export TF_CUDA_VERSION=10
-export TF_CUDNN_VERSION=7
-export TF_NEED_TENSORRT=1
-export TENSORRT_INSTALL_PATH=/usr/local/tensorrt
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.6)
-export TF2_BEHAVIOR=1
-export PROJECT_NAME="tensorflow_gpu"
-export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib"
-export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_37,sm_52,sm_60,sm_61,compute_70
-
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36"
-
-set +e
-bazel test --config=cuda --config=opt \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \
-  --linkopt=-lrt \
-  --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
-  --test_lang_filters=py \
-  --test_tag_filters=${tag_filters} \
-  --build_tag_filters=${tag_filters} \
-  --test_timeout="300,450,1200,3600" --local_test_jobs=4 \
-  --test_output=errors --verbose_failures=true --keep_going \
-  --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
-  -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
-test_xml_summary_exit
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py36_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py36_pip.sh
deleted file mode 100644
index 9aa724c..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py36_pip.sh
+++ /dev/null
@@ -1,55 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.6
-# Update bazel
-install_bazelisk
-
-# Export required variables for running pip.sh
-export OS_TYPE="UBUNTU"
-export CONTAINER_TYPE="GPU"
-export TF_PYTHON_VERSION='python3.6'
-
-# Run configure.
-export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION})
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Export optional variables for running pip.sh
-export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py36'
-export TF_BUILD_FLAGS="--config=release_gpu_linux "
-export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \
---distinct_host_configuration=false \
---action_env=TF_CUDA_VERSION=10 --action_env=TF_CUDNN_VERSION=7 --test_env=TF2_BEHAVIOR=1 \
---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \
---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \
---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute "
-export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... "
-export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean"
-#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo.
-export TF_PROJECT_NAME="tensorflow_gpu"
-export TF_PIP_TEST_ROOT="pip_test"
-
-# To build both tensorflow and tensorflow-gpu pip packages
-export TF_BUILD_BOTH_GPU_PACKAGES=1
-
-./tensorflow/tools/ci_build/builds/pip_new.sh
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py37_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py37_nonpip.sh
deleted file mode 100644
index 225b2cf..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py37_nonpip.sh
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.7
-# Update bazel
-install_bazelisk
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=1
-export TF_CUDA_VERSION=10
-export TF_CUDNN_VERSION=7
-export TF_NEED_TENSORRT=1
-export TENSORRT_INSTALL_PATH=/usr/local/tensorrt
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.7)
-export TF2_BEHAVIOR=1
-export PROJECT_NAME="tensorflow_gpu"
-export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib"
-export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_37,sm_52,sm_60,sm_61,compute_70
-
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py37"
-
-set +e
-bazel test --config=cuda --config=opt \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \
-  --linkopt=-lrt \
-  --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
-  --test_lang_filters=py \
-  --build_tag_filters=${tag_filters} \
-  --test_tag_filters=${tag_filters} \
-  --test_timeout="300,450,1200,3600" --local_test_jobs=4 \
-  --test_output=errors --verbose_failures=true --keep_going \
-  --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
-  -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
-test_xml_summary_exit
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py37_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py37_pip.sh
deleted file mode 100644
index d884a48..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py37_pip.sh
+++ /dev/null
@@ -1,65 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.7
-# Update bazel
-install_bazelisk
-
-# Export required variables for running pip.sh
-export OS_TYPE="UBUNTU"
-export CONTAINER_TYPE="GPU"
-export TF_PYTHON_VERSION='python3.7'
-
-# Run configure.
-export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION})
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Export optional variables for running pip.sh
-export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py37'
-# TODO (pkanwar): Revert this CL (cl/326069644) once the cuda 11 migration is complete.
-export TF_BUILD_FLAGS="--config=release_common "
-export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \
---distinct_host_configuration=false \
---action_env=TF_CUDA_VERSION=11 --action_env=TF_CUDNN_VERSION=8 --test_env=TF2_BEHAVIOR=1 \
---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \
---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \
---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
---config=cuda \
---config=tensorrt \
---action_env=CUDA_TOOLKIT_PATH=/usr/local/cuda-11.0 --action_env=TF_NEED_TENSORRT=1 \
---action_env=TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_37,sm_52,sm_60,sm_61,compute_70 \
---action_env=TENSORRT_INSTALL_PATH=/usr/local/tensorrt \
---action_env=LD_LIBRARY_PATH=/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib \
---action_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc-5 \
---config=avx_linux \
---crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain"
-export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... "
-export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean"
-#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo.
-export TF_PROJECT_NAME="tensorflow_gpu"
-export TF_PIP_TEST_ROOT="pip_test"
-
-# To build both tensorflow and tensorflow-gpu pip packages
-export TF_BUILD_BOTH_GPU_PACKAGES=1
-
-./tensorflow/tools/ci_build/builds/pip_new.sh
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py38_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py38_nonpip.sh
deleted file mode 100644
index f7678b7..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py38_nonpip.sh
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.8
-# Update bazel
-update_bazel_linux
-
-# Run configure.
-export TF_NEED_GCP=1
-export TF_NEED_HDFS=1
-export TF_NEED_S3=1
-export TF_NEED_CUDA=1
-export TF_CUDA_VERSION=10
-export TF_CUDNN_VERSION=7
-export TF_NEED_TENSORRT=1
-export TENSORRT_INSTALL_PATH=/usr/local/tensorrt
-export CC_OPT_FLAGS='-mavx'
-export PYTHON_BIN_PATH=$(which python3.8)
-export TF2_BEHAVIOR=1
-export PROJECT_NAME="tensorflow_gpu"
-export LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$TENSORRT_INSTALL_PATH/lib"
-export TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_37,sm_52,sm_60,sm_61,compute_70
-
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-tag_filters="gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py38"
-
-set +e
-bazel test --config=cuda --config=opt \
-  --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain \
-  --linkopt=-lrt \
-  --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
-  --test_lang_filters=py \
-  --build_tag_filters=${tag_filters} \
-  --test_tag_filters=${tag_filters} \
-  --test_timeout="300,450,1200,3600" --local_test_jobs=4 \
-  --test_output=errors --verbose_failures=true --keep_going \
-  --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
-  -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
-test_xml_summary_exit
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py38_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py38_pip.sh
deleted file mode 100644
index d8838e7..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/gpu_py38_pip.sh
+++ /dev/null
@@ -1,55 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-set -x
-
-source tensorflow/tools/ci_build/release/common.sh
-
-install_ubuntu_16_pip_deps pip3.8
-# Update bazel
-update_bazel_linux
-
-# Export required variables for running pip.sh
-export OS_TYPE="UBUNTU"
-export CONTAINER_TYPE="GPU"
-export TF_PYTHON_VERSION='python3.8'
-
-# Run configure.
-export PYTHON_BIN_PATH=$(which ${TF_PYTHON_VERSION})
-yes "" | "$PYTHON_BIN_PATH" configure.py
-
-# Get the default test targets for bazel.
-source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh
-
-# Export optional variables for running pip.sh
-export TF_TEST_FILTER_TAGS='gpu,requires-gpu,-no_gpu,-no_oss,-oss_serial,-no_oss_py38'
-export TF_BUILD_FLAGS="--config=release_gpu_linux "
-export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filters=${TF_TEST_FILTER_TAGS} \
---distinct_host_configuration=false \
---action_env=TF_CUDA_VERSION=10 --action_env=TF_CUDNN_VERSION=7 --test_env=TF2_BEHAVIOR=1 \
---config=cuda --test_output=errors --local_test_jobs=4 --test_lang_filters=py \
---verbose_failures=true --keep_going --define=no_tensorflow_py_deps=true \
---run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute "
-export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... "
-export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean"
-#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo.
-export TF_PROJECT_NAME="tensorflow_gpu"
-export TF_PIP_TEST_ROOT="pip_test"
-
-# To build both tensorflow and tensorflow-gpu pip packages
-export TF_BUILD_BOTH_GPU_PACKAGES=1
-
-./tensorflow/tools/ci_build/builds/pip_new.sh
diff --git a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/sanity.sh b/tensorflow/tools/ci_build/rel/ubuntu_cuda11/sanity.sh
deleted file mode 100644
index 4fc600d..0000000
--- a/tensorflow/tools/ci_build/rel/ubuntu_cuda11/sanity.sh
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-set -e
-
-# Install latest bazel
-source tensorflow/tools/ci_build/release/common.sh
-install_bazelisk
-which bazel
-
-# We need py3 lint
-sudo pip3 install pep8
-
-# TODO(gunan): figure out why we get stuck with later versions of pylint.
-# Install pylint.
-sudo python3 -m pip install setuptools --upgrade
-sudo python2 -m pip install pylint==1.6.4
-sudo python3 -m pip install pylint==1.6.4
-
-# TODO(yifeif): print pylint version for debug. remove later.
-python3 -m pylint --version
-
-# Run tensorflow sanity checks.
-tensorflow/tools/ci_build/ci_sanity.sh
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/common_win_cuda11.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/common_win_cuda11.bat
deleted file mode 100644
index 81f2c86..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/common_win_cuda11.bat
+++ /dev/null
@@ -1,24 +0,0 @@
-:: Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-echo on
-
-SET TF_CUDA_VERSION=11.0
-SET TF_CUDNN_VERSION=8
-
-REM TODO(sanjoy): This script should be removed once common_win.bat
-REM defaults to CUDA 11.
-
-CALL tensorflow\tools\ci_build\release\common_win.bat
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_libtensorflow.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_libtensorflow.bat
deleted file mode 100644
index e583c5e..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_libtensorflow.bat
+++ /dev/null
@@ -1,20 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-call tensorflow\tools\ci_build\windows\cpu\bazel\run_libtensorflow.bat || exit /b 1
-
-copy lib_package %TF_ARTIFACTS_DIR%\lib_package
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py35.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py35.bat
deleted file mode 100644
index c87dac6..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py35.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python35
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-:: TODO(angerson) Set this based on some env param before merging with nightly
-call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu"
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py36.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py36.bat
deleted file mode 100644
index df29b8e..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py36.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python36
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-:: TODO(angerson) Set this based on some env param before merging with nightly
-call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu"
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py37.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py37.bat
deleted file mode 100644
index 3ed6fe3..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py37.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python37
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-:: TODO(angerson) Set this based on some env param before merging with nightly
-call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu"
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py38.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py38.bat
deleted file mode 100644
index 71d68e6..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/cpu_py38.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python38
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-:: TODO(angerson) Set this based on some env param before merging with nightly
-call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu"
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_libtensorflow.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_libtensorflow.bat
deleted file mode 100644
index bd15e83..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_libtensorflow.bat
+++ /dev/null
@@ -1,20 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-call tensorflow\tools\ci_build\windows\gpu\bazel\run_libtensorflow.bat || exit /b
-
-copy lib_package %TF_ARTIFACTS_DIR%\lib_package
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_pip_on_cpu.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_pip_on_cpu.bat
deleted file mode 100644
index 207359b..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_pip_on_cpu.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python36
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-call tensorflow\tools\ci_build\windows\integration\gpu_pip_on_cpu\run.bat
-
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py35.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py35.bat
deleted file mode 100644
index d8ba563..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py35.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python35
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-:: TODO(angerson) Set this based on some env param before merging with nightly
-call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py36.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py36.bat
deleted file mode 100644
index 58cf423..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py36.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python36
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-:: TODO(angerson) Set this based on some env param before merging with nightly
-call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py37.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py37.bat
deleted file mode 100644
index 60c6eb6..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py37.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python37
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-:: TODO(angerson) Set this based on some env param before merging with nightly
-call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly
diff --git a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py38.bat b/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py38.bat
deleted file mode 100644
index da909ba..0000000
--- a/tensorflow/tools/ci_build/rel/windows_cuda11/gpu_py38.bat
+++ /dev/null
@@ -1,21 +0,0 @@
-:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-::
-:: Licensed under the Apache License, Version 2.0 (the "License");
-:: you may not use this file except in compliance with the License.
-:: You may obtain a copy of the License at
-::
-::     http://www.apache.org/licenses/LICENSE-2.0
-::
-:: Unless required by applicable law or agreed to in writing, software
-:: distributed under the License is distributed on an "AS IS" BASIS,
-:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-:: See the License for the specific language governing permissions and
-:: limitations under the License.
-:: =============================================================================
-
-SET PYTHON_DIRECTORY=Python38
-
-CALL tensorflow\tools\ci_build\rel\windows_cuda11\common_win_cuda11.bat
-
-:: TODO(angerson) Set this based on some env param before merging with nightly
-call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly
diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh
index c3b5bd9..2c7a425 100644
--- a/tensorflow/tools/ci_build/release/common.sh
+++ b/tensorflow/tools/ci_build/release/common.sh
@@ -142,6 +142,7 @@
   ${SUDO_CMD} ${PIP_CMD} install portpicker
   ${SUDO_CMD} ${PIP_CMD} install scipy
   ${SUDO_CMD} ${PIP_CMD} install scikit-learn
+  ${SUDO_CMD} ${PIP_CMD} install typing_extensions
   ${SUDO_CMD} ${PIP_CMD} install --upgrade tb-nightly
   ${PIP_CMD} install --user --upgrade flatbuffers
   ${PIP_CMD} install --user --upgrade attrs
@@ -178,6 +179,7 @@
   "${PIP_CMD}" install portpicker --user
   "${PIP_CMD}" install scipy --user
   "${PIP_CMD}" install scikit-learn --user
+  "${PIP_CMD}" install typing_extensions --user
   "${PIP_CMD}" install PyYAML==3.13 --user
   # b/156523241
   "${PIP_CMD}" install --force-reinstall --user --upgrade tf-estimator-nightly
@@ -220,6 +222,7 @@
   ${SUDO_CMD} ${PIP_CMD} install numpy==1.16.0
   ${SUDO_CMD} ${PIP_CMD} install gast==0.3.3
   ${SUDO_CMD} ${PIP_CMD} install h5py==2.10.0
+  ${SUDO_CMD} ${PIP_CMD} install typing_extensions
   ${SUDO_CMD} ${PIP_CMD} install --upgrade grpcio
   ${SUDO_CMD} ${PIP_CMD} install --upgrade tb-nightly
   ${PIP_CMD} install --user --upgrade flatbuffers
diff --git a/tensorflow/tools/ci_build/release/common_win.bat b/tensorflow/tools/ci_build/release/common_win.bat
index 6b9b533..267ea67 100644
--- a/tensorflow/tools/ci_build/release/common_win.bat
+++ b/tensorflow/tools/ci_build/release/common_win.bat
@@ -57,13 +57,14 @@
 @REM handle this case.
 %PIP_EXE% install gast==0.3.3
 %PIP_EXE% install astunparse==1.6.3
+%PIP_EXE% install typing_extensions
 
 :: Set cuda related environment variables. If we are not using CUDA, these are not used.
 IF NOT DEFINED TF_CUDA_VERSION (
-  SET TF_CUDA_VERSION=10.1
+  SET TF_CUDA_VERSION=11.0
 )
 IF NOT DEFINED TF_CUDNN_VERSION (
-  SET TF_CUDNN_VERSION=7
+  SET TF_CUDNN_VERSION=8
 )
 SET TF_CUDA_COMPUTE_CAPABILITIES=sm_35,sm_37,sm_52,sm_60,sm_61,compute_70
 SET CUDA_TOOLKIT_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v%TF_CUDA_VERSION%
diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
index 23016f7..e767a0c 100644
--- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
@@ -55,8 +55,8 @@
 export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH"
 
 # Setting default values to CUDA related environment variables
-export TF_CUDA_VERSION=${TF_CUDA_VERSION:-10.1}
-export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7}
+export TF_CUDA_VERSION=${TF_CUDA_VERSION:-11.0}
+export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-8}
 export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-6.0}
 export CUDA_TOOLKIT_PATH=${CUDA_TOOLKIT_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"}
 export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"}
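Note on the hunk above: the `${VAR:-default}` expansions make the new CUDA 11.0 / cuDNN 8 values fallbacks only; any value already exported by the caller still wins. A minimal illustrative sketch (not part of the diff):

    # Unset or empty -> the new default applies:
    unset TF_CUDA_VERSION
    echo "${TF_CUDA_VERSION:-11.0}"    # prints 11.0
    # Explicitly exported -> the caller's value is kept:
    TF_CUDA_VERSION=10.1
    echo "${TF_CUDA_VERSION:-11.0}"    # prints 10.1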
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 93305a0..ed4dada 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -337,8 +337,8 @@
 tensorflow::ProfilerSession::~ProfilerSession
 
 [profiler_server_impl] # profiler
-tensorflow::ProfilerServer::StartProfilerServer
-tensorflow::ProfilerServer::~ProfilerServer
+tensorflow::profiler::ProfilerServer::StartProfilerServer
+tensorflow::profiler::ProfilerServer::~ProfilerServer
 
 [profiler_client_impl] # profiler
 tensorflow::profiler::ProfileGrpc
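The renamed entries above track `ProfilerServer` moving into the `tensorflow::profiler` namespace; this allowlist feeds the Windows .def file generation, so a symbol spelled with the old namespace would silently drop out of the DLL exports. A hypothetical spot-check on a Linux build (the shared-object path is an assumption, not something this diff establishes):

    # Hypothetical: verify the namespaced profiler symbols are exported.
    nm -D --defined-only bazel-bin/tensorflow/python/_pywrap_tensorflow_internal.so \
      | c++filt | grep 'tensorflow::profiler::ProfilerServer'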
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
index b8bbbbd..83e01bd 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
@@ -22,37 +22,34 @@
 ARG UBUNTU_VERSION=18.04
 
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.2.39-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
-        libcublas-dev=10.2.1.243-1 \
+        libcublas-${CUDA/./-} \
+        libcublas-dev-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
         cuda-nvrtc-dev-${CUDA/./-} \
         cuda-cudart-dev-${CUDA/./-} \
-        cuda-cufft-dev-${CUDA/./-} \
-        cuda-curand-dev-${CUDA/./-} \
-        cuda-cusolver-dev-${CUDA/./-} \
-        cuda-cusparse-dev-${CUDA/./-} \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
-        libcudnn7-dev=${CUDNN}+cuda${CUDA} \
+        libcufft-dev-${CUDA/./-} \
+        libcurand-dev-${CUDA/./-} \
+        libcusolver-dev-${CUDA/./-} \
+        libcusparse-dev-${CUDA/./-} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
+        libcudnn8-dev=${CUDNN}+cuda${CUDA} \
         libcurl3-dev \
         libfreetype6-dev \
         libhdf5-serial-dev \
@@ -67,7 +64,7 @@
         git \
         && \
     find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
-    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v8.a
 
 # Install TensorRT if not building for PowerPC
 RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
index 81d50dc..60a3e57 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
@@ -22,37 +22,34 @@
 ARG UBUNTU_VERSION=18.04
 
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.2.39-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
-        libcublas-dev=10.2.1.243-1 \
+        libcublas-${CUDA/./-} \
+        libcublas-dev-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
         cuda-nvrtc-dev-${CUDA/./-} \
         cuda-cudart-dev-${CUDA/./-} \
-        cuda-cufft-dev-${CUDA/./-} \
-        cuda-curand-dev-${CUDA/./-} \
-        cuda-cusolver-dev-${CUDA/./-} \
-        cuda-cusparse-dev-${CUDA/./-} \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
-        libcudnn7-dev=${CUDNN}+cuda${CUDA} \
+        libcufft-dev-${CUDA/./-} \
+        libcurand-dev-${CUDA/./-} \
+        libcusolver-dev-${CUDA/./-} \
+        libcusparse-dev-${CUDA/./-} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
+        libcudnn8-dev=${CUDNN}+cuda${CUDA} \
         libcurl3-dev \
         libfreetype6-dev \
         libhdf5-serial-dev \
@@ -67,7 +64,7 @@
         git \
         && \
     find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
-    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v8.a
 
 # Install TensorRT if not building for PowerPC
 RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
index d4d913c..911678b 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
@@ -22,17 +22,17 @@
 ARG UBUNTU_VERSION=18.04
 
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.2.39-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
@@ -40,17 +40,14 @@
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
+        libcublas-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
-        cuda-cufft-${CUDA/./-} \
-        cuda-curand-${CUDA/./-} \
-        cuda-cusolver-${CUDA/./-} \
-        cuda-cusparse-${CUDA/./-} \
+        libcufft-${CUDA/./-} \
+        libcurand-${CUDA/./-} \
+        libcusolver-${CUDA/./-} \
+        libcusparse-${CUDA/./-} \
         curl \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
         libfreetype6-dev \
         libhdf5-serial-dev \
         libzmq3-dev \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
index f563f2f..228513d 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
@@ -22,17 +22,17 @@
 ARG UBUNTU_VERSION=18.04
 
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.2.39-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
@@ -40,17 +40,14 @@
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
+        libcublas-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
-        cuda-cufft-${CUDA/./-} \
-        cuda-curand-${CUDA/./-} \
-        cuda-cusolver-${CUDA/./-} \
-        cuda-cusparse-${CUDA/./-} \
+        libcufft-${CUDA/./-} \
+        libcurand-${CUDA/./-} \
+        libcusolver-${CUDA/./-} \
+        libcusparse-${CUDA/./-} \
         curl \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
         libfreetype6-dev \
         libhdf5-serial-dev \
         libzmq3-dev \
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
index 5b4b2b7..ed310f3 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
@@ -1,35 +1,32 @@
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.2.39-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
-        libcublas-dev=10.2.1.243-1 \
+        libcublas-${CUDA/./-} \
+        libcublas-dev-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
         cuda-nvrtc-dev-${CUDA/./-} \
         cuda-cudart-dev-${CUDA/./-} \
-        cuda-cufft-dev-${CUDA/./-} \
-        cuda-curand-dev-${CUDA/./-} \
-        cuda-cusolver-dev-${CUDA/./-} \
-        cuda-cusparse-dev-${CUDA/./-} \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
-        libcudnn7-dev=${CUDNN}+cuda${CUDA} \
+        libcufft-dev-${CUDA/./-} \
+        libcurand-dev-${CUDA/./-} \
+        libcusolver-dev-${CUDA/./-} \
+        libcusparse-dev-${CUDA/./-} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
+        libcudnn8-dev=${CUDNN}+cuda${CUDA} \
         libcurl3-dev \
         libfreetype6-dev \
         libhdf5-serial-dev \
@@ -44,7 +41,7 @@
         git \
         && \
     find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
-    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+    rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v8.a
 
 # Install TensorRT if not building for PowerPC
 RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
index 555caf0..b2a7b46 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
@@ -1,15 +1,15 @@
 ARG ARCH=
-ARG CUDA=10.1
+ARG CUDA=11.0
 FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
 # ARCH and CUDA are specified again because the FROM directive resets ARGs
 # (but their default value is retained if set previously)
 ARG ARCH
 ARG CUDA
-ARG CUDNN=7.6.4.38-1
-ARG CUDNN_MAJOR_VERSION=7
+ARG CUDNN=8.0.2.39-1
+ARG CUDNN_MAJOR_VERSION=8
 ARG LIB_DIR_PREFIX=x86_64
-ARG LIBNVINFER=6.0.1-1
-ARG LIBNVINFER_MAJOR_VERSION=6
+ARG LIBNVINFER=7.1.3-1
+ARG LIBNVINFER_MAJOR_VERSION=7
 
 # Needed for string substitution
 SHELL ["/bin/bash", "-c"]
@@ -17,17 +17,14 @@
 RUN apt-get update && apt-get install -y --no-install-recommends \
         build-essential \
         cuda-command-line-tools-${CUDA/./-} \
-        # There appears to be a regression in libcublas10=10.2.2.89-1 which
-        # prevents cublas from initializing in TF. See
-        # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
-        libcublas10=10.2.1.243-1 \ 
+        libcublas-${CUDA/./-} \
         cuda-nvrtc-${CUDA/./-} \
-        cuda-cufft-${CUDA/./-} \
-        cuda-curand-${CUDA/./-} \
-        cuda-cusolver-${CUDA/./-} \
-        cuda-cusparse-${CUDA/./-} \
+        libcufft-${CUDA/./-} \
+        libcurand-${CUDA/./-} \
+        libcusolver-${CUDA/./-} \
+        libcusparse-${CUDA/./-} \
         curl \
-        libcudnn7=${CUDNN}+cuda${CUDA} \
+        libcudnn8=${CUDNN}+cuda${CUDA} \
         libfreetype6-dev \
         libhdf5-serial-dev \
         libzmq3-dev \
diff --git a/tensorflow/tools/docs/generate2.py b/tensorflow/tools/docs/generate2.py
index 66715ca..a7fa5d8 100644
--- a/tensorflow/tools/docs/generate2.py
+++ b/tensorflow/tools/docs/generate2.py
@@ -220,6 +220,9 @@
 
   doc_generator.build(output_dir)
 
+  if gen_report:
+    return
+
   out_path = pathlib.Path(output_dir)
 
   expected_path_contents = {
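
The early return added above lets report-only runs skip the sanity checks on the rendered site, since no site tree is produced in that mode. A condensed, runnable sketch of the resulting control flow (build_report and check_outputs are hypothetical stand-ins for the surrounding generate2.py code):

    def build_report(output_dir):
        pass  # stand-in for doc_generator.build(output_dir) in report mode

    def check_outputs(output_dir):
        print("verifying", output_dir)  # stand-in for the expected-path checks

    def build_docs(output_dir, gen_report):
        build_report(output_dir)
        if gen_report:
            return  # report-only runs render no docs to verify
        check_outputs(output_dir)

    build_docs("/tmp/docs", gen_report=True)   # skips verification
    build_docs("/tmp/docs", gen_report=False)  # runs it
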
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 74585cb..71b1727 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -218,6 +218,7 @@
         "@sobol_data//:LICENSE",
         "@tblib_archive//:LICENSE",
         "@termcolor_archive//:COPYING.txt",
+        "@typing_extensions_archive//:LICENSE",
         "@zlib//:zlib.h",
         "@clog//:LICENSE",
         "@cpuinfo//:LICENSE",
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index d2002b5..9916b15 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -81,7 +81,7 @@
     "//tensorflow/python:tf_optimizer",
     "//tensorflow/python:compare_test_proto_py",
     "//tensorflow/core:image_testdata",
-    "//tensorflow/core:lmdb_testdata",
+    "//tensorflow/core/lib/lmdb:lmdb_testdata",
     "//tensorflow/core/kernels/cloud:bigquery_reader_ops",
     "//tensorflow/python/debug:grpc_tensorflow_server.par",
     "//tensorflow/python/feature_column:vocabulary_testdata",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 54021af..4d72eca 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -67,6 +67,7 @@
     'tensorboard >= 2.3.0, < 3',
     'tensorflow_estimator >= 2.3.0, < 2.4.0',
     'termcolor >= 1.1.0',
+    'typing_extensions >= 3.7.4.2',
     'wrapt >= 1.11.1',
     'wheel >= 0.26',
     'six >= 1.12.0',
@@ -259,6 +260,7 @@
     version=_VERSION.replace('-', ''),
     description=DOCLINES[0],
     long_description='\n'.join(DOCLINES[2:]),
+    long_description_content_type="text/markdown",
     url='https://www.tensorflow.org/',
     download_url='https://github.com/tensorflow/tensorflow/tags',
     author='Google Inc.',
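
Two independent packaging changes land in setup.py: typing_extensions becomes a runtime requirement (matching the new Bazel archive below), and long_description_content_type tells PyPI to render the long description as Markdown rather than the reStructuredText default. A minimal setuptools sketch showing both fields together (a hypothetical package, not TensorFlow's actual setup() call):

    from setuptools import setup

    setup(
        name="example-package",  # hypothetical name, for illustration only
        version="0.0.1",
        install_requires=["typing_extensions >= 3.7.4.2"],
        long_description="# Example\n\nThis text is rendered as Markdown on PyPI.",
        long_description_content_type="text/markdown",
    )
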
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 24b811f..2d8747e 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -515,6 +515,29 @@
     )
 
     tf_http_archive(
+        name = "typing_extensions_archive",
+        build_file = clean_dep("//third_party:typing_extensions.BUILD"),
+        sha256 = "79ee589a3caca649a9bfd2a8de4709837400dfa00b6cc81962a1e6a1815969ae",
+        strip_prefix = "typing_extensions-3.7.4.2/src_py3",
+        system_build_file = clean_dep("//third_party/systemlibs:typing_extensions.BUILD"),
+        urls = [
+            "http://mirror.tensorflow.org/files.pythonhosted.org/packages/6a/28/d32852f2af6b5ead85d396249d5bdf450833f3a69896d76eb480d9c5e406/typing_extensions-3.7.4.2.tar.gz",
+            "https://files.pythonhosted.org/packages/6a/28/d32852f2af6b5ead85d396249d5bdf450833f3a69896d76eb480d9c5e406/typing_extensions-3.7.4.2.tar.gz",
+        ],
+    )
+
+    filegroup_external(
+        name = "typing_extensions_license",
+        licenses = ["notice"],  # PSFL
+        sha256_urls = {
+            "ff17ce94e102024deb68773eb1cc74ca76da4e658f373531f0ac22d68a6bb1ad": [
+                "http://mirror.tensorflow.org/raw.githubusercontent.com/python/typing/master/typing_extensions/LICENSE",
+                "https://raw.githubusercontent.com/python/typing/master/typing_extensions/LICENSE",
+            ],
+        },
+    )
+
+    tf_http_archive(
         name = "opt_einsum_archive",
         build_file = clean_dep("//third_party:opt_einsum.BUILD"),
         sha256 = "d3d464b4da7ef09e444c30e4003a27def37f85ff10ff2671e5f7d7813adac35b",
@@ -699,8 +722,8 @@
     )
 
     # Check out LLVM and MLIR from llvm-project.
-    LLVM_COMMIT = "e75bc5c791e0e8dbe79f7453e55af9e8d03c9cc0"
-    LLVM_SHA256 = "9c22f59d50853329cd0105ecb95256ad345313372ddda593030cd81b7c72e657"
+    LLVM_COMMIT = "1d01fc100bb5bef5f5eaf92520b2e52f64ee1d6e"
+    LLVM_SHA256 = "a8a2503b98945e91e55df114a7b3739c88c900cf14839fba2221fd1f8cfc1d5a"
     LLVM_URLS = [
         "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
         "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl
index 1fbe629..d7099e5 100644
--- a/third_party/flatbuffers/build_defs.bzl
+++ b/third_party/flatbuffers/build_defs.bzl
@@ -24,6 +24,7 @@
         out_prefix = "",
         includes = [],
         include_paths = [],
+        compatible_with = [],
         flatc_args = DEFAULT_FLATC_ARGS,
         reflection_name = "",
         reflection_visibility = None,
@@ -43,6 +44,8 @@
           single source targets. Usually is a directory name.
       includes: Optional, list of filegroups of schemas that the srcs depend on.
       include_paths: Optional, list of paths the includes files can be found in.
+      compatible_with: Optional, passed to genrule for environments this rule
+          can be built for.
       flatc_args: Optional, list of additional arguments to pass to flatc.
       reflection_name: Optional, if set this will generate the flatbuffer
         reflection binaries for the schemas.
@@ -72,6 +75,7 @@
         srcs = srcs,
         outs = outs,
         output_to_bindir = output_to_bindir,
+        compatible_with = compatible_with,
         tools = includes + [flatc_path],
         cmd = genrule_cmd,
         message = "Generating flatbuffer files for %s:" % (name),
@@ -97,6 +101,7 @@
             srcs = srcs,
             outs = reflection_outs,
             output_to_bindir = output_to_bindir,
+            compatible_with = compatible_with,
             tools = includes + [flatc_path],
             cmd = reflection_genrule_cmd,
             message = "Generating flatbuffer reflection binary for %s:" % (name),
@@ -111,6 +116,7 @@
         #         native.FilesetEntry(files = reflection_outs),
         #     ],
         #     visibility = reflection_visibility,
+        #     compatible_with = compatible_with,
         # )
 
 def flatbuffer_cc_library(
@@ -120,6 +126,7 @@
         out_prefix = "",
         includes = [],
         include_paths = [],
+        compatible_with = [],
         flatc_args = DEFAULT_FLATC_ARGS,
         visibility = None,
         srcs_filegroup_visibility = None,
@@ -175,6 +182,8 @@
       includes: Optional, list of filegroups of schemas that the srcs depend on.
           ** SEE REMARKS BELOW **
       include_paths: Optional, list of paths the includes files can be found in.
+      compatible_with: Optional, passed to genrule for environments this rule
+          can be built for.
       flatc_args: Optional list of additional arguments to pass to flatc
           (e.g. --gen-mutable).
       visibility: The visibility of the generated cc_library. By default, use the
@@ -198,6 +207,7 @@
         out_prefix = out_prefix,
         includes = includes,
         include_paths = include_paths,
+        compatible_with = compatible_with,
         flatc_args = flatc_args,
         reflection_name = reflection_name,
         reflection_visibility = visibility,
@@ -215,6 +225,7 @@
         includes = ["."],
         linkstatic = 1,
         visibility = visibility,
+        compatible_with = compatible_with,
     )
 
     # A filegroup for the `srcs`. That is, all the schema files for this
@@ -223,6 +234,7 @@
         name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name),
         srcs = srcs,
         visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility,
+        compatible_with = compatible_with,
     )
 
 # Custom provider to track dependencies transitively.
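
compatible_with restricts the build environments a target may be built for; the macro change above simply accepts the new argument and threads it through to every rule the macro instantiates (the genrules, the cc_library, and the srcs filegroup). A plain-Python analogue of that forwarding pattern (all names here are placeholders):

    def genrule(name, compatible_with=(), **kwargs):
        # Stand-in for Bazel's genrule; just records the constraint list.
        print(name, "compatible_with:", list(compatible_with))

    def flatbuffer_library(name, compatible_with=(), **kwargs):
        # Forward the caller's constraints to each generated rule, as
        # build_defs.bzl now does.
        genrule(name + "_gen", compatible_with=compatible_with)
        genrule(name + "_reflection", compatible_with=compatible_with)

    flatbuffer_library("schema", compatible_with=["//buildenv:embedded"])
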
diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py
index 091cd32..80f3430 100644
--- a/third_party/gpus/find_cuda_config.py
+++ b/third_party/gpus/find_cuda_config.py
@@ -176,6 +176,7 @@
       "include/*-linux-gnu",
       "extras/CUPTI/include",
       "include/cuda/CUPTI",
+      "local/cuda/extras/CUPTI/include",
   ]
 
 
@@ -188,6 +189,8 @@
       "lib/*-linux-gnu",
       "lib/x64",
       "extras/CUPTI/*",
+      "local/cuda/lib64",
+      "local/cuda/extras/CUPTI/lib64",
   ]
 
 
@@ -268,12 +271,14 @@
   nvcc_path, nvcc_version = _find_versioned_file(base_paths, [
       "",
       "bin",
+      "local/cuda/bin",
   ], nvcc_name, cuda_version, get_nvcc_version)
 
   nvvm_path = _find_file(base_paths, [
       "nvvm/libdevice",
       "share/cuda",
       "lib/nvidia-cuda-toolkit/libdevice",
+      "local/cuda/nvvm/libdevice",
   ], "libdevice*.10.bc")
 
   cupti_header_path = _find_file(base_paths, _header_paths(), "cupti.h")
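
The added "local/cuda/..." entries let the probing reach a toolkit installed under /usr/local/cuda even when the configured base path is just /usr. A simplified sketch of the lookup (not the actual _find_file implementation):

    import glob
    import os

    def find_file(base_paths, relative_paths, pattern):
        # Join every base prefix with every relative path and glob for the
        # file, returning the first hit.
        for base in base_paths:
            for rel in relative_paths:
                matches = glob.glob(os.path.join(base, rel, pattern))
                if matches:
                    return matches[0]
        return None

    # With "local/cuda/bin" in the list, a base path of "/usr" now finds
    # /usr/local/cuda/bin/nvcc.
    print(find_file(["/usr"], ["bin", "local/cuda/bin"], "nvcc*"))
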
diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD
index 3d5717b..a63c38c 100644
--- a/third_party/llvm/llvm.autogenerated.BUILD
+++ b/third_party/llvm/llvm.autogenerated.BUILD
@@ -1759,6 +1759,7 @@
         "lib/CodeGen/*.c",
         "lib/CodeGen/*.cpp",
         "lib/CodeGen/*.inc",
+        "lib/CodeGen/LiveDebugValues/*.cpp",
         "lib/CodeGen/*.h",
     ]),
     hdrs = glob([
@@ -3323,6 +3324,7 @@
         ":IPO",
         ":InstCombine",
         ":Instrumentation",
+        ":ObjCARC",
         ":Scalar",
         ":Support",
         ":Target",
diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl
index 48b986b..dcbaab9 100644
--- a/third_party/llvm/llvm.bzl
+++ b/third_party/llvm/llvm.bzl
@@ -190,6 +190,7 @@
     "HAVE_PTHREAD_H": 1,
     "HAVE_SIGNAL_H": 1,
     "HAVE_STDINT_H": 1,
+    "HAVE_SYSEXITS_H": 1,
     "HAVE_SYS_IOCTL_H": 1,
     "HAVE_SYS_MMAN_H": 1,
     "HAVE_SYS_PARAM_H": 1,
diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD
index 445b547..8c1a5eb 100644
--- a/third_party/mkl_dnn/mkldnn_v1.BUILD
+++ b/third_party/mkl_dnn/mkldnn_v1.BUILD
@@ -75,8 +75,8 @@
         "src/cpu/**/*.cpp",
         "src/cpu/**/*.hpp",
         "src/cpu/xbyak/*.h",
-        "src/cpu/jit_utils/jitprofiling/*.c",
-        "src/cpu/jit_utils/jitprofiling/*.h",
+        "src/cpu/x64/jit_utils/jitprofiling/*.c",
+        "src/cpu/x64/jit_utils/jitprofiling/*.h",
     ]) + [
         ":dnnl_config_h",
         ":dnnl_version_h",
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 60284cc..7d1495b 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -121,11 +121,13 @@
     srcs = [
         "lib/CAPI/IR/AffineMap.cpp",
         "lib/CAPI/IR/IR.cpp",
+        "lib/CAPI/IR/StandardAttributes.cpp",
         "lib/CAPI/IR/StandardTypes.cpp",
     ],
     hdrs = [
         "include/mlir-c/AffineMap.h",
         "include/mlir-c/IR.h",
+        "include/mlir-c/StandardAttributes.h",
         "include/mlir-c/StandardTypes.h",
         "include/mlir/CAPI/AffineMap.h",
         "include/mlir/CAPI/IR.h",
@@ -593,6 +595,7 @@
         ":LinalgToLLVM",
         ":LinalgToSPIRV",
         ":LinalgToStandard",
+        ":OpenMPToLLVM",
         ":SCFToGPUPass",
         ":SCFToStandard",
         ":SPIRVToLLVM",
@@ -1124,6 +1127,7 @@
         ":ControlFlowInterfaces",
         ":IR",
         ":LLVMOpsIncGen",
+        ":OpenMPDialect",
         ":SideEffectInterfaces",
         ":Support",
         "@llvm-project//llvm:AsmParser",
@@ -1762,6 +1766,116 @@
     ],
 )
 
+cc_library(
+    name = "PDLDialect",
+    srcs = glob([
+        "lib/Dialect/PDL/IR/*.cpp",
+        "lib/Dialect/PDL/IR/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Dialect/PDL/IR/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":InferTypeOpInterface",
+        ":PDLOpsIncGen",
+        ":SideEffects",
+        ":Support",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+filegroup(
+    name = "PDLOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/PDL/IR/PDLBase.td",
+        "include/mlir/Dialect/PDL/IR/PDLOps.td",
+        "include/mlir/IR/SymbolInterfaces.td",
+        "include/mlir/Interfaces/SideEffectInterfaces.td",
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl(
+    name = "PDLOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/PDL/IR/PDLOps.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/PDL/IR/PDLOps.cpp.inc",
+        ),
+        (
+            "-gen-dialect-decls",
+            "include/mlir/Dialect/PDL/IR/PDLOpsDialect.h.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/PDL/IR/PDLOps.td",
+    td_srcs = [
+        ":PDLOpsTdFiles",
+    ],
+)
+
+cc_library(
+    name = "PDLInterpDialect",
+    srcs = glob([
+        "lib/Dialect/PDLInterp/IR/*.cpp",
+        "lib/Dialect/PDLInterp/IR/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Dialect/PDLInterp/IR/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":InferTypeOpInterface",
+        ":PDLDialect",
+        ":PDLInterpOpsIncGen",
+        ":SideEffects",
+        ":Support",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
+filegroup(
+    name = "PDLInterpOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/PDL/IR/PDLBase.td",
+        "include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td",
+        "include/mlir/Interfaces/SideEffectInterfaces.td",
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl(
+    name = "PDLInterpOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            "-gen-op-decls",
+            "include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.h.inc",
+        ),
+        (
+            "-gen-op-defs",
+            "include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc",
+        ),
+        (
+            "-gen-dialect-decls -dialect=pdl_interp",
+            "include/mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.h.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td",
+    td_srcs = [
+        ":PDLInterpOpsTdFiles",
+    ],
+)
+
 # TODO(gcmn): Update SPIRV dependencies so that they map better to cmake files.
 filegroup(
     name = "SPIRVOpsTdFiles",
@@ -2876,6 +2990,9 @@
         ":NVVMDialect",
         ":OpenACCDialect",
         ":OpenMPDialect",
+        ":OpenMPToLLVM",
+        ":PDLDialect",
+        ":PDLInterpDialect",
         ":QuantOps",
         ":QuantPassIncGen",
         ":ROCDLDialect",
@@ -3300,6 +3417,30 @@
     ],
 )
 
+cc_library(
+    name = "OpenMPToLLVM",
+    srcs = glob([
+        "lib/Conversion/OpenMPToLLVM/*.cpp",
+        "lib/Conversion/OpenMPToLLVM/*.h",
+    ]) + ["lib/Conversion/PassDetail.h"],
+    hdrs = glob([
+        "include/mlir/Conversion/OpenMPToLLVM/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":ConversionPassIncGen",
+        ":IR",
+        ":LLVMDialect",
+        ":OpenMPDialect",
+        ":Pass",
+        ":StandardOps",
+        ":StandardToLLVM",
+        ":Transforms",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:Support",
+    ],
+)
+
 ## QuantOps dialect
 filegroup(
     name = "QuantizationOpsTdFiles",
@@ -3542,6 +3683,7 @@
         ":LinalgOps",
         ":LinalgTransforms",
         ":Pass",
+        ":SCFDialect",
         ":SCFToStandard",
         ":StandardOps",
         ":StandardToLLVM",
@@ -3779,6 +3921,7 @@
         ":EDSC",
         ":IR",
         ":LLVMDialect",
+        ":LinalgTransforms",
         ":Pass",
         ":SCFDialect",
         ":StandardOps",
diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD
index ac27bab..bea0710 100644
--- a/third_party/mlir/test.BUILD
+++ b/third_party/mlir/test.BUILD
@@ -186,6 +186,7 @@
         "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:SPIRVDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:StandardOpsTransforms",
         "@llvm-project//mlir:Support",
diff --git a/third_party/nasm/BUILD.system b/third_party/nasm/BUILD.system
index 7f74da7..52f6081 100644
--- a/third_party/nasm/BUILD.system
+++ b/third_party/nasm/BUILD.system
@@ -5,8 +5,14 @@
     visibility = ["//visibility:public"],
 )
 
+genrule(
+    name = "lnnasmlink",
+    outs = ["nasmlink"],
+    cmd = "ln -s $$(which nasm) $@",
+)
+
 sh_binary(
     name = "nasm",
-    srcs = ["nasm"],
+    srcs = ["nasmlink"],
     visibility = ["@libjpeg_turbo//:__pkg__"],
 )
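
Before this change the sh_binary named "nasm" listed srcs = ["nasm"], i.e. it pointed at its own label rather than at the system assembler; the new lnnasmlink genrule materializes a nasmlink symlink to whatever nasm is on PATH, and the sh_binary wraps that instead. What the genrule's cmd does, expressed in Python (a sketch, not part of the build):

    import os
    import shutil

    real_nasm = shutil.which("nasm")       # $$(which nasm) in the genrule cmd
    if real_nasm:
        os.symlink(real_nasm, "nasmlink")  # ln -s $$(which nasm) $@
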
diff --git a/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt b/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt
index f54ecbd..4edc5f0 100644
--- a/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt
+++ b/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt
@@ -36,6 +36,10 @@
     argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "clip"
+    argspec: "args=[\'a\', \'a_min\', \'a_max\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "from_tensor"
     argspec: "args=[\'cls\', \'tensor\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/third_party/systemlibs/jsoncpp.BUILD b/third_party/systemlibs/jsoncpp.BUILD
index 7d54f92..b5951e3 100644
--- a/third_party/systemlibs/jsoncpp.BUILD
+++ b/third_party/systemlibs/jsoncpp.BUILD
@@ -5,35 +5,8 @@
     visibility = ["//visibility:public"],
 )
 
-HEADERS = [
-    "include/json/allocator.h",
-    "include/json/assertions.h",
-    "include/json/autolink.h",
-    "include/json/config.h",
-    "include/json/features.h",
-    "include/json/forwards.h",
-    "include/json/json.h",
-    "include/json/reader.h",
-    "include/json/value.h",
-    "include/json/version.h",
-    "include/json/writer.h",
-]
-
-genrule(
-    name = "link_headers",
-    outs = HEADERS,
-    cmd = """
-      for i in $(OUTS); do
-        i=$${i##*/}
-        ln -sf $(INCLUDEDIR)/jsoncpp/json/$$i $(@D)/include/json/$$i
-      done
-    """,
-)
-
 cc_library(
     name = "jsoncpp",
-    hdrs = HEADERS,
-    includes = ["."],
     linkopts = ["-ljsoncpp"],
     visibility = ["//visibility:public"],
 )
diff --git a/third_party/systemlibs/protobuf.BUILD b/third_party/systemlibs/protobuf.BUILD
index 118135d..ef3e0c9 100644
--- a/third_party/systemlibs/protobuf.BUILD
+++ b/third_party/systemlibs/protobuf.BUILD
@@ -15,8 +15,13 @@
 HEADERS = [
     "google/protobuf/any.pb.h",
     "google/protobuf/any.proto",
+    "google/protobuf/api.pb.h",
+    "google/protobuf/api.proto",
     "google/protobuf/arena.h",
     "google/protobuf/compiler/importer.h",
+    "google/protobuf/compiler/plugin.h",
+    "google/protobuf/compiler/plugin.pb.h",
+    "google/protobuf/compiler/plugin.proto",
     "google/protobuf/descriptor.h",
     "google/protobuf/descriptor.pb.h",
     "google/protobuf/descriptor.proto",
@@ -32,9 +37,15 @@
     "google/protobuf/io/zero_copy_stream_impl_lite.h",
     "google/protobuf/map.h",
     "google/protobuf/repeated_field.h",
+    "google/protobuf/source_context.pb.h",
+    "google/protobuf/source_context.proto",
+    "google/protobuf/struct.pb.h",
+    "google/protobuf/struct.proto",
     "google/protobuf/text_format.h",
     "google/protobuf/timestamp.pb.h",
     "google/protobuf/timestamp.proto",
+    "google/protobuf/type.pb.h",
+    "google/protobuf/type.proto",
     "google/protobuf/util/json_util.h",
     "google/protobuf/util/type_resolver_util.h",
     "google/protobuf/wrappers.pb.h",
@@ -102,3 +113,75 @@
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
 )
+
+proto_library(
+    name = "any_proto",
+    srcs = ["google/protobuf/any.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "api_proto",
+    srcs = ["google/protobuf/api.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "compiler_plugin_proto",
+    srcs = ["google/protobuf/compiler/plugin.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "descriptor_proto",
+    srcs = ["google/protobuf/descriptor.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "duration_proto",
+    srcs = ["google/protobuf/duration.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "empty_proto",
+    srcs = ["google/protobuf/empty.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "field_mask_proto",
+    srcs = ["google/protobuf/field_mask.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "source_context_proto",
+    srcs = ["google/protobuf/source_context.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "struct_proto",
+    srcs = ["google/protobuf/struct.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "timestamp_proto",
+    srcs = ["google/protobuf/timestamp.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "type_proto",
+    srcs = ["google/protobuf/type.proto"],
+    visibility = ["//visibility:public"],
+)
+
+proto_library(
+    name = "wrappers_proto",
+    srcs = ["google/protobuf/wrappers.proto"],
+    visibility = ["//visibility:public"],
+)
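
These proto_library targets mirror protobuf's well-known types (Any, Api, Duration, Struct, Timestamp, Type, ...) so that builds against a system protobuf can still resolve them by label. As a reminder of what the well-known types provide, a small Python example using one of them (assumes the protobuf package is installed):

    from google.protobuf import timestamp_pb2

    t = timestamp_pb2.Timestamp()
    t.GetCurrentTime()         # helper method defined for the Timestamp type
    print(t.ToJsonString())    # e.g. "2020-09-01T12:34:56.789Z"
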
diff --git a/third_party/typing_extensions.BUILD b/third_party/typing_extensions.BUILD
new file mode 100644
index 0000000..f3b6c26
--- /dev/null
+++ b/third_party/typing_extensions.BUILD
@@ -0,0 +1,20 @@
+# Description:
+#   Backports for the typing module to older Python versions. See
+#   https://github.com/python/typing/blob/master/typing_extensions/README.rst
+
+licenses(["notice"])  # PSF
+
+py_library(
+    name = "typing_extensions",
+    srcs = ["typing_extensions.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+)
+
+genrule(
+    name = "license",
+    srcs = ["@astunparse_license"],
+    outs = ["LICENSE"],
+    cmd = "cp $< $@",
+    visibility = ["//visibility:public"],
+)
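
typing_extensions backports newer typing features (Protocol, Literal, Final, ...) to older Python versions, which is why it is added both as a pip requirement and as this vendored py_library. A small illustrative example of the kind of feature it provides (not taken from TensorFlow itself):

    import os

    from typing_extensions import Protocol

    class SupportsClose(Protocol):
        # Any object with a close() method satisfies this structural type.
        def close(self) -> None: ...

    def shutdown(resource: SupportsClose) -> None:
        resource.close()

    shutdown(open(os.devnull))  # file objects satisfy SupportsClose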